
Checkpoint with optimizer state in PyTorch - ML Experiment: Train & Evaluate

Experiment - Checkpoint with optimizer state
Problem: You are training a neural network in PyTorch and want to save its progress so that you can resume training later without losing the optimizer state.
Current Metrics: Training accuracy: 85%, Validation accuracy: 82%, Loss: 0.45
Issue: Currently, only the model weights are saved. When training resumes, the optimizer state (for example, momentum buffers) is lost, which can slow convergence and make training unstable.
Your Task
Modify the training code to save and load both the model weights and optimizer state so training can resume seamlessly.
Use PyTorch's torch.save and torch.load functions.
Do not change the model architecture or optimizer type.
Ensure the checkpoint includes epoch number, model state_dict, and optimizer state_dict.
Solution
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Simple model definition
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
    def forward(self, x):
        return self.fc(x)

# Create dummy dataset
X = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=10)

# Initialize model and optimizer
model = SimpleNet()
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

# Function to save checkpoint
def save_checkpoint(epoch, model, optimizer, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, path)

# Function to load checkpoint
def load_checkpoint(model, optimizer, path):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

# Training loop with checkpoint saving
start_epoch = 0
num_epochs = 5
checkpoint_path = 'checkpoint.pth'

# To resume from an existing checkpoint, uncomment the next line
# start_epoch = load_checkpoint(model, optimizer, checkpoint_path)

for epoch in range(start_epoch, num_epochs):
    model.train()
    for inputs, labels in loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    save_checkpoint(epoch + 1, model, optimizer, checkpoint_path)

print(f'Training completed up to epoch {num_epochs}')
Added functions to save and load checkpoints including model and optimizer states.
Modified training loop to save checkpoint after each epoch.
Included loading checkpoint code (commented) to resume training from saved state.
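
Instead of commenting the load call in and out by hand, the resume step can check whether a checkpoint file exists. The following is a minimal sketch, not part of the original solution: the helper name `resume_if_available` and the `device` parameter are illustrative assumptions, and the checkpoint keys match the ones used in the solution above.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def resume_if_available(model, optimizer, path, device="cpu"):
    """Hypothetical helper: return the epoch to resume from, or 0 if no file exists."""
    if not os.path.exists(path):
        return 0
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]


model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")

    # No checkpoint on disk yet, so training would start at epoch 0.
    fresh_start = resume_if_available(model, optimizer, path)

    # Pretend three epochs were completed and saved.
    torch.save({
        "epoch": 3,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, path)

    # Now the helper restores both state dicts and reports epoch 3.
    resumed_start = resume_if_available(model, optimizer, path)
```

The returned value can be passed directly as `start_epoch` to `range(start_epoch, num_epochs)`.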
Results Interpretation

Before: Only model weights saved. Resuming training loses optimizer state, causing slower convergence.

After: Both model and optimizer states saved and loaded. Training resumes smoothly with stable loss and accuracy.

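Saving the optimizer state along with the model weights is essential for resuming training effectively in PyTorch. The optimizer state holds per-parameter buffers such as momentum, plus the current hyperparameters of each parameter group (including the learning rate), so restoring it helps avoid instability after a restart. A learning rate scheduler keeps its own state and must be checkpointed separately.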
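
To see what the optimizer state actually contains, it can help to inspect `optimizer.state_dict()` directly. The sketch below assumes SGD with `momentum=0.9`, which differs from the solution above: plain SGD without momentum keeps no momentum buffers, so there is little per-parameter state to lose.

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
# Assumption for illustration: momentum is enabled so that buffers exist.
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Before any step, the per-parameter state is empty.
num_buffers_before = len(optimizer.state_dict()["state"])

# One training step on random data creates the momentum buffers.
loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()

state_after = optimizer.state_dict()
# "state" holds per-parameter buffers, keyed by parameter index.
num_buffers_after = len(state_after["state"])
has_momentum_buffer = "momentum_buffer" in state_after["state"][0]
# "param_groups" holds hyperparameters such as lr and momentum.
saved_lr = state_after["param_groups"][0]["lr"]
saved_momentum = state_after["param_groups"][0]["momentum"]
```

The two top-level keys, `state` and `param_groups`, are what `optimizer.load_state_dict()` restores when training resumes.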
Bonus Experiment
Try adding learning rate scheduler state to the checkpoint and restore it when loading.
Hint: Save the scheduler's state_dict in the checkpoint dictionary and load it in the same way as the optimizer's.
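
One possible shape for the bonus experiment is sketched below. The choice of `StepLR`, the `scheduler_state_dict` key, and the `build` helper are assumptions made for illustration; any scheduler that exposes `state_dict()` and `load_state_dict()` should follow the same pattern.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def build():
    """Hypothetical helper: create a fresh model, optimizer, and scheduler."""
    model = nn.Linear(10, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    return model, optimizer, scheduler


model, optimizer, scheduler = build()

# Simulate two epochs; a real loop would run forward/backward before step().
for _ in range(2):
    optimizer.step()
    scheduler.step()  # halves the learning rate each epoch

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")
    torch.save({
        "epoch": 2,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
    }, path)

    # Restore into freshly constructed objects, as after a restart.
    new_model, new_optimizer, new_scheduler = build()
    checkpoint = torch.load(path, map_location="cpu")
    new_model.load_state_dict(checkpoint["model_state_dict"])
    new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    new_scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

restored_lr = new_optimizer.param_groups[0]["lr"]
restored_last_epoch = new_scheduler.last_epoch
```

After loading, the new scheduler continues from its saved epoch count rather than restarting the decay schedule from the initial learning rate.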

Practice

1. What is the main reason to save the optimizer state along with the model in a PyTorch checkpoint?
easy
A. To speed up the model's inference time
B. To reduce the model size on disk
C. To resume training with the same learning rate and momentum settings
D. To convert the model to a different format

Solution

  1. Step 1: Understand what optimizer state contains

    The optimizer's state_dict contains the hyperparameters of each parameter group (such as learning rate and momentum) and per-parameter buffers (such as momentum buffers) that accumulate during training.
  2. Step 2: Reason why saving optimizer state is important

    Saving the optimizer state allows training to resume exactly where it left off, preserving these settings.
  3. Final Answer:

    To resume training with the same learning rate and momentum settings -> Option C
  4. Quick Check:

    Optimizer state saves training settings = C [OK]
Hint: Optimizer state saves training progress settings [OK]
Common Mistakes:
  • Thinking optimizer state reduces model size
  • Confusing optimizer state with model weights
  • Believing optimizer state affects inference speed
2. Which of the following is the correct way to save a checkpoint including model and optimizer states in PyTorch?
easy
A. torch.save(model, 'checkpoint.pth')
B. torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'checkpoint.pth')
C. torch.save(optimizer, 'checkpoint.pth')
D. torch.save({'model': model, 'optimizer': optimizer}, 'checkpoint.pth')

Solution

  1. Step 1: Identify correct saving method for states

    PyTorch recommends saving state_dict() of model and optimizer for checkpoints.
  2. Step 2: Check each option

    torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'checkpoint.pth') saves state_dict() of both model and optimizer in a dictionary, which is correct.
  3. Final Answer:

    torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'checkpoint.pth') -> Option B
  4. Quick Check:

    Save state_dict() for model and optimizer = B [OK]
Hint: Save state_dict() of model and optimizer in dict [OK]
Common Mistakes:
  • Saving full model object instead of state_dict
  • Saving optimizer object directly
  • Not saving optimizer state at all
3. Given this code snippet, what will be printed?
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Save checkpoint
checkpoint = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(checkpoint, 'cp.pth')

# Load checkpoint
loaded = torch.load('cp.pth')
optimizer.load_state_dict(loaded['optimizer'])
print(optimizer.param_groups[0]['lr'])
medium
A. 0.1
B. 0.01
C. 1.0
D. Error: optimizer state not loaded

Solution

  1. Step 1: Understand optimizer initialization

    Optimizer is created with learning rate 0.1 and saved in checkpoint.
  2. Step 2: Loading optimizer state restores learning rate

    Loading the optimizer state_dict restores the parameter groups stored in the checkpoint, so the learning rate is 0.1.
  3. Final Answer:

    0.1 -> Option A
  4. Quick Check:

    Loaded optimizer lr = 0.1 [OK]
Hint: Loaded optimizer keeps saved learning rate [OK]
Common Mistakes:
  • Assuming learning rate resets to default
  • Forgetting to load optimizer state
  • Confusing model and optimizer states
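
The effect is easier to see when the new optimizer is constructed with a different learning rate than the saved one. The snippet below is a small sketch of that case; the value 0.5 is an arbitrary choice for illustration, and the state is passed in memory rather than through a file to keep the example short.

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)

# Optimizer whose state is "saved" (lr=0.1, as in the question).
saved_optimizer = optim.SGD(model.parameters(), lr=0.1)
saved_state = saved_optimizer.state_dict()

# A new optimizer built with a different learning rate.
new_optimizer = optim.SGD(model.parameters(), lr=0.5)
lr_before_load = new_optimizer.param_groups[0]["lr"]

# load_state_dict replaces the hyperparameters with the saved ones.
new_optimizer.load_state_dict(saved_state)
lr_after_load = new_optimizer.param_groups[0]["lr"]

print(lr_before_load, lr_after_load)
```

The learning rate passed to the constructor is overwritten by the checkpointed value once the state is loaded.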
4. You saved a checkpoint with model and optimizer states but when loading, training behaves as if optimizer settings are lost. What is the most likely mistake?
medium
A. Not calling optimizer.load_state_dict() after loading checkpoint
B. Saving model.state_dict() instead of model
C. Using torch.save() instead of torch.load()
D. Not setting model.eval() before saving

Solution

  1. Step 1: Identify cause of lost optimizer settings

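    If the optimizer state is not loaded, training continues with the hyperparameters of the newly constructed optimizer and with empty buffers (for example, no accumulated momentum).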
  2. Step 2: Check common mistakes

    Not calling optimizer.load_state_dict() after loading checkpoint causes this issue.
  3. Final Answer:

    Not calling optimizer.load_state_dict() after loading checkpoint -> Option A
  4. Quick Check:

    Load optimizer state to keep settings = A [OK]
Hint: Always load optimizer state after loading checkpoint [OK]
Common Mistakes:
  • Saving full model instead of state_dict
  • Confusing torch.save and torch.load usage
  • Thinking model.eval() matters here; it changes layer behavior such as dropout and batch norm, not optimizer state
5. You want to save a checkpoint that allows resuming training exactly, including epoch number and best loss so far. Which is the best way to structure the checkpoint dictionary?
hard
A. {'epoch': epoch, 'model': model.state_dict()}
B. {'model': model, 'optimizer': optimizer, 'epoch': epoch}
C. {'model_state': model.state_dict(), 'loss': best_loss}
D. {'epoch': epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'best_loss': best_loss}

Solution

  1. Step 1: Identify required checkpoint components

    To resume training exactly, save epoch, model state, optimizer state, and best loss.
  2. Step 2: Evaluate options

    {'epoch': epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'best_loss': best_loss} includes all required keys with correct state_dict() usage.
  3. Final Answer:

    {'epoch': epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'best_loss': best_loss} -> Option D
  4. Quick Check:

    Save epoch, model, optimizer, loss in checkpoint = D [OK]
Hint: Include epoch, model, optimizer, and loss in checkpoint dict [OK]
Common Mistakes:
  • Saving full model or optimizer objects
  • Omitting optimizer state
  • Not saving epoch or loss for training resume
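
A checkpoint shaped like option D can be written and read back as sketched below. The function signatures, the file name `best.pth`, and the sample values for `epoch` and `best_loss` are illustrative assumptions; only the dictionary keys come from the question.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(path, epoch, model, optimizer, best_loss):
    """Write everything needed to resume: epoch, both state dicts, best loss."""
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "best_loss": best_loss,
    }, path)


def load_checkpoint(path, model, optimizer):
    """Restore model and optimizer in place; return (epoch, best_loss)."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint["epoch"], checkpoint["best_loss"]


trained = nn.Linear(10, 2)
trained_opt = optim.SGD(trained.parameters(), lr=0.1)

# Fresh objects with different random weights, as after a restart.
restored = nn.Linear(10, 2)
restored_opt = optim.SGD(restored.parameters(), lr=0.1)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "best.pth")
    save_checkpoint(path, epoch=7, model=trained, optimizer=trained_opt, best_loss=0.45)
    epoch, best_loss = load_checkpoint(path, restored, restored_opt)

# The restored model now carries the saved weights.
weights_match = torch.equal(trained.weight, restored.weight)
```

Keeping `best_loss` in the checkpoint lets a resumed run continue comparing against the best value seen so far instead of resetting that comparison.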