PyTorch · ML · ~20 mins

Checkpoint with optimizer state in PyTorch - ML Experiment: Train & Evaluate

Experiment - Checkpoint with optimizer state
Problem: You are training a neural network using PyTorch. You want to save your model's progress so you can resume training later without losing optimizer state.
Current Metrics: Training accuracy: 85%, Validation accuracy: 82%, Loss: 0.45
Issue: Currently, you save only the model weights. When resuming training, the optimizer state is lost, causing slower convergence and unstable training.
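For contrast, the problematic "before" pattern looks like the sketch below (the file name and toy model are illustrative, not part of the experiment):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer accumulates a momentum buffer
model(torch.randn(4, 10)).sum().backward()
optimizer.step()

# Saves only the weights; the momentum buffer is never written to disk
torch.save(model.state_dict(), 'weights_only.pth')

# On resume the model is restored, but a fresh optimizer starts empty:
restored = nn.Linear(10, 2)
restored.load_state_dict(torch.load('weights_only.pth'))
new_optimizer = optim.SGD(restored.parameters(), lr=0.1, momentum=0.9)
assert len(new_optimizer.state) == 0  # momentum must be rebuilt from zero
```

The fresh optimizer has no memory of the accumulated momentum, which is exactly what causes the slow, unstable restart.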
Your Task
Modify the training code to save and load both the model weights and optimizer state so training can resume seamlessly.
Use PyTorch's torch.save and torch.load functions.
Do not change the model architecture or optimizer type.
Ensure the checkpoint includes epoch number, model state_dict, and optimizer state_dict.
Solution
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Simple model definition
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
    def forward(self, x):
        return self.fc(x)

# Create dummy dataset
X = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=10)

# Initialize model and optimizer
model = SimpleNet()
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

# Function to save checkpoint
def save_checkpoint(epoch, model, optimizer, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, path)

# Function to load checkpoint
def load_checkpoint(model, optimizer, path):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

# Training loop with checkpoint saving
start_epoch = 0
num_epochs = 5
checkpoint_path = 'checkpoint.pth'

# Resume from a checkpoint if one exists
import os
if os.path.exists(checkpoint_path):
    start_epoch = load_checkpoint(model, optimizer, checkpoint_path)

for epoch in range(start_epoch, num_epochs):
    model.train()
    for inputs, labels in loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    save_checkpoint(epoch + 1, model, optimizer, checkpoint_path)

print(f'Training completed up to epoch {num_epochs}')
Added functions to save and load checkpoints including model and optimizer states.
Modified training loop to save checkpoint after each epoch.
Included loading checkpoint code (commented) to resume training from saved state.
Results Interpretation

Before: Only model weights saved. Resuming training loses optimizer state, causing slower convergence.

After: Both model and optimizer states saved and loaded. Training resumes smoothly with stable loss and accuracy.

Saving optimizer state along with model weights is essential for resuming training effectively in PyTorch. It preserves momentum buffers and adaptive learning-rate statistics (such as Adam's moment estimates), preventing training instability on restart. Note that a learning rate scheduler keeps its own state, which must be saved separately (see the bonus experiment below).
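To convince yourself the optimizer state really survives the round trip, you can compare the momentum buffers before saving and after loading. This is a minimal self-check, assuming SGD with momentum and the checkpoint format from the solution (the file name is illustrative):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer holds a non-trivial momentum buffer
model(torch.randn(8, 10)).sum().backward()
optimizer.step()

torch.save({
    'epoch': 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}, 'ckpt.pth')

# Restore into completely fresh objects
model2 = nn.Linear(10, 2)
optimizer2 = optim.SGD(model2.parameters(), lr=0.1, momentum=0.9)
checkpoint = torch.load('ckpt.pth')
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])

# The momentum buffer is byte-for-byte identical after the round trip
buf_before = optimizer.state_dict()['state'][0]['momentum_buffer']
buf_after = optimizer2.state_dict()['state'][0]['momentum_buffer']
assert torch.equal(buf_before, buf_after)
```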
Bonus Experiment
Try adding learning rate scheduler state to the checkpoint and restore it when loading.
💡 Hint
Save the scheduler's state_dict in the checkpoint dictionary and load it similarly to the optimizer.
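A sketch of the extended checkpoint, assuming a StepLR scheduler as an example (any scheduler exposing state_dict/load_state_dict works the same way; the file name is illustrative):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

# Advance a few epochs so the scheduler has non-trivial state
for _ in range(3):
    optimizer.step()
    scheduler.step()

torch.save({
    'epoch': 3,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict()  # the new entry
}, 'ckpt_sched.pth')

# Restore into fresh objects; the scheduler resumes at the right epoch
model2 = nn.Linear(10, 2)
optimizer2 = optim.SGD(model2.parameters(), lr=0.1)
scheduler2 = optim.lr_scheduler.StepLR(optimizer2, step_size=2, gamma=0.5)
checkpoint = torch.load('ckpt_sched.pth')
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler2.load_state_dict(checkpoint['scheduler_state_dict'])
```

Without the scheduler entry, a restored run would restart the learning rate schedule from its initial value instead of continuing the decay where it left off.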