Saving a checkpoint with the optimizer state lets you pause and continue training later without losing progress.
Checkpoint with optimizer state in PyTorch
# Save the model, optimizer, and training progress to PATH
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'loss': loss_value
}, PATH)
# Later: load the checkpoint and restore each piece
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss_value = checkpoint['loss']
Call model.state_dict() and optimizer.state_dict() to get their current states.
Loading the optimizer state restores its hyperparameters (such as the learning rate) and its internal buffers (such as momentum), so training continues smoothly.
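To see what "continues smoothly" means in practice, here is a minimal sketch. The random toy data, SGD with momentum=0.9, and the use of copy.deepcopy as a stand-in for torch.save/torch.load are assumptions made for illustration. It compares an uninterrupted run against two resumed runs: one that restores the optimizer state and one that restores only the model.

```python
import copy

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy regression data (made up for illustration)
inputs = torch.randn(8, 2)
targets = torch.randn(8, 1)
criterion = nn.MSELoss()


def train_step(model, optimizer):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()


# Train a few steps with momentum so the optimizer builds up internal state
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
for _ in range(3):
    train_step(model, optimizer)

# Snapshot both states (deepcopy stands in for torch.save / torch.load here)
saved_model = copy.deepcopy(model.state_dict())
saved_opt = copy.deepcopy(optimizer.state_dict())

# Reference: the original run simply takes one more step
train_step(model, optimizer)
reference = model.weight.detach().clone()

# Resume A: restore the model AND the optimizer state
model_a = nn.Linear(2, 1)
model_a.load_state_dict(saved_model)
opt_a = optim.SGD(model_a.parameters(), lr=0.1, momentum=0.9)
opt_a.load_state_dict(saved_opt)
train_step(model_a, opt_a)

# Resume B: restore only the model; momentum buffers start from scratch
model_b = nn.Linear(2, 1)
model_b.load_state_dict(saved_model)
opt_b = optim.SGD(model_b.parameters(), lr=0.1, momentum=0.9)
train_step(model_b, opt_b)

print(torch.allclose(model_a.weight, reference))  # True
print(torch.allclose(model_b.weight, reference))  # False
```

Resume B has correct weights but empty momentum buffers, so its next update differs from the uninterrupted run.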
Saving to and loading from checkpoint.pth:
torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, 'checkpoint.pth')
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'epoch': current_epoch}, 'checkpoint.pth')
This code trains a simple model for one step, saves the model and optimizer states along with epoch and loss, then loads them back and prints the saved epoch and loss.
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

# Create model and optimizer
model = SimpleNet()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Dummy data
inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
targets = torch.tensor([[1.0], [2.0]])

# Loss function
criterion = nn.MSELoss()

# Training step
model.train()
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

# Save checkpoint
checkpoint_path = 'checkpoint.pth'
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': 1,
    'loss': loss.item()
}, checkpoint_path)

# Load checkpoint
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
loaded_epoch = checkpoint['epoch']
loaded_loss = checkpoint['loss']
print(f"Loaded epoch: {loaded_epoch}")
print(f"Loaded loss: {loaded_loss:.4f}")
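Building on the example above, a resumed run usually continues from the epoch after the saved one. This is a sketch under stated assumptions: the saved "epoch" is taken to mean the last completed epoch, and the helper names (make_model_and_optimizer, train), NUM_EPOCHS, and the temporary checkpoint path are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical settings for this sketch
NUM_EPOCHS = 5
CHECKPOINT_PATH = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")

inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
targets = torch.tensor([[1.0], [2.0]])
criterion = nn.MSELoss()


def make_model_and_optimizer():
    model = nn.Linear(2, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    return model, optimizer


def train(model, optimizer, start_epoch, stop_epoch):
    """Run epochs start_epoch..stop_epoch-1, checkpointing after each one."""
    for epoch in range(start_epoch, stop_epoch):
        model.train()
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        torch.save({
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,  # last *completed* epoch
            "loss": loss.item(),
        }, CHECKPOINT_PATH)


# First session: run epochs 0-2, then stop (e.g. the job is interrupted)
model, optimizer = make_model_and_optimizer()
train(model, optimizer, start_epoch=0, stop_epoch=3)

# Second session: rebuild the objects, restore their state, and continue
model, optimizer = make_model_and_optimizer()
checkpoint = torch.load(CHECKPOINT_PATH)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # resume after the saved epoch
print(f"Resuming at epoch {start_epoch}")  # Resuming at epoch 3
train(model, optimizer, start_epoch=start_epoch, stop_epoch=NUM_EPOCHS)

final = torch.load(CHECKPOINT_PATH)
print(f"Last saved epoch: {final['epoch']}")  # Last saved epoch: 4
```

Whether "epoch" stores the last finished epoch or the next one to run is a convention; whichever you pick, apply the +1 (or not) consistently when resuming.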
Always save both model and optimizer states to continue training smoothly.
Include extra info like epoch and loss to track training progress.
Use torch.load(PATH, map_location=torch.device('cpu')) when loading a checkpoint from a GPU-trained model on a CPU-only machine.
A checkpoint saves the model and optimizer states so you can pause and resume training.
Loading the optimizer state restores training settings such as the learning rate.
Include epoch and loss in checkpoint to track training progress.
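A minimal sketch of the map_location tip, assuming a device-agnostic setup: the checkpoint is written from whatever device is available (GPU if present), and every stored tensor is remapped to the CPU at load time. Adam is used here only because it keeps per-parameter tensors in its state; the temporary file path is an assumption.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")

# Train on the GPU if one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(2, 1).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# One step so Adam creates its per-parameter state tensors
loss = model(torch.randn(4, 2, device=device)).pow(2).mean()
loss.backward()
optimizer.step()

torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, path)

# On a CPU-only machine: remap every stored tensor to the CPU while loading
checkpoint = torch.load(path, map_location=torch.device("cpu"))

cpu_model = nn.Linear(2, 1)  # constructed on the CPU
cpu_model.load_state_dict(checkpoint["model_state_dict"])
cpu_optimizer = optim.Adam(cpu_model.parameters(), lr=1e-3)
cpu_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

print(cpu_model.weight.device)  # cpu
```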
Practice
Solution
Step 1: Understand what optimizer state contains
The optimizer state includes hyperparameters such as the learning rate and momentum coefficient, plus per-parameter buffers (for example, momentum buffers) that affect how training progresses.
Step 2: Reason why saving optimizer state is important
Saving the optimizer state allows training to resume where it left off, preserving these settings and buffers.
Final Answer:
To resume training with the same learning rate and momentum settings -> Option C
Quick Check:
Optimizer state saves training settings = C [OK]
- Thinking optimizer state reduces model size
- Confusing optimizer state with model weights
- Believing optimizer state affects inference speed
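To make Step 1 concrete, this small sketch inspects an SGD optimizer's state_dict(); the momentum=0.9 setting and the random inputs are assumptions for illustration. Hyperparameters sit under "param_groups", while per-parameter buffers appear under "state" only after the first step().

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Before any step, SGD has hyperparameters but no per-parameter buffers yet
print(len(optimizer.state_dict()["state"]))  # 0

loss = model(torch.randn(4, 2)).pow(2).mean()
loss.backward()
optimizer.step()

sd = optimizer.state_dict()
print(sorted(sd.keys()))  # ['param_groups', 'state']

# Hyperparameters live in param_groups
group = sd["param_groups"][0]
print(group["lr"], group["momentum"])  # 0.1 0.9

# Per-parameter buffers live in state, keyed by parameter index
for index, buffers in sd["state"].items():
    print(index, list(buffers.keys()))
```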
Solution
Step 1: Identify correct saving method for states
PyTorch recommends saving the state_dict() of the model and the optimizer for checkpoints.
Step 2: Check each option
torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'checkpoint.pth') saves the state_dict() of both the model and the optimizer in a dictionary, which is correct.
Final Answer:
torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'checkpoint.pth') -> Option B
Quick Check:
Save state_dict() for model and optimizer = B [OK]
- Saving full model object instead of state_dict
- Saving optimizer object directly
- Not saving optimizer state at all
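The sketch below illustrates why state_dict() is the recommended thing to save: it is a plain dictionary of tensors and basic Python values, so it can be loaded with torch.load(..., weights_only=True) (an argument in recent PyTorch releases, and the default starting with 2.6) and poured into freshly constructed objects. The temporary path is an assumption.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Save the state_dicts (plain dicts of tensors and numbers), not the objects
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, path)

# weights_only=True restricts unpickling to tensors and basic Python types
loaded = torch.load(path, weights_only=True)

# Rebuild the objects from code, then copy the saved state into them
new_model = nn.Linear(2, 1)
new_model.load_state_dict(loaded["model"])
new_optimizer = optim.SGD(new_model.parameters(), lr=0.1)
new_optimizer.load_state_dict(loaded["optimizer"])

print(torch.equal(new_model.weight, model.weight))  # True
```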
import torch
import torch.nn as nn
import torch.optim as optim
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
# Save checkpoint
checkpoint = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(checkpoint, 'cp.pth')
# Load checkpoint
loaded = torch.load('cp.pth')
optimizer.load_state_dict(loaded['optimizer'])
print(optimizer.param_groups[0]['lr'])
Solution
Step 1: Understand optimizer initialization
The optimizer is created with learning rate 0.1, and its state is saved in the checkpoint.
Step 2: Loading optimizer state restores learning rate
Loading the optimizer state_dict restores the saved learning rate of 0.1.
Final Answer:
0.1 -> Option A
Quick Check:
Loaded optimizer lr = 0.1 [OK]
- Assuming learning rate resets to default
- Forgetting to load optimizer state
- Confusing model and optimizer states
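A related detail, sketched below with arbitrary illustrative learning rates (0.1, 0.5, 0.01): load_state_dict() overwrites the hyperparameters the optimizer was constructed with. If you deliberately want a new learning rate after resuming, set it in param_groups after loading.

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)

# Original run, saved with lr = 0.1
optimizer = optim.SGD(model.parameters(), lr=0.1)
saved = optimizer.state_dict()

# Resumed run constructed with a different lr
resumed = optim.SGD(model.parameters(), lr=0.5)
lr_before_load = resumed.param_groups[0]["lr"]

# Loading replaces the constructor's hyperparameters with the saved ones
resumed.load_state_dict(saved)
lr_after_load = resumed.param_groups[0]["lr"]

# To change the lr on purpose, edit param_groups after loading
for group in resumed.param_groups:
    group["lr"] = 0.01

print(lr_before_load, lr_after_load, resumed.param_groups[0]["lr"])  # 0.5 0.1 0.01
```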
Solution
Step 1: Identify cause of lost optimizer settings
If the optimizer state is not loaded, training continues with whatever settings the new optimizer was constructed with, and its internal buffers (such as momentum) start empty.
Step 2: Check common mistakes
Not calling optimizer.load_state_dict() after loading the checkpoint causes this issue.
Final Answer:
Not calling optimizer.load_state_dict() after loading checkpoint -> Option A
Quick Check:
Load optimizer state to keep settings [OK]
- Saving full model instead of state_dict
- Confusing torch.save and torch.load usage
- Setting model.eval() changes layer behavior (e.g., dropout, batch norm), not the optimizer
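One way to avoid forgetting optimizer.load_state_dict() is to restore both objects through a single function. load_checkpoint below is a hypothetical helper, not a PyTorch API, and the temporary path is an assumption. The sketch also shows the symptom of the bug: an optimizer whose state is still empty after "resuming".

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def load_checkpoint(path, model, optimizer):
    """Hypothetical helper: restore model and optimizer together so neither is forgotten."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint.get("epoch", 0)


path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")

# Create a checkpoint after one training step with momentum
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.randn(4, 2)).pow(2).mean().backward()
optimizer.step()
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": 1,
}, path)

# Bug: only the model is restored, so the optimizer has no momentum buffers
buggy_model = nn.Linear(2, 1)
buggy_model.load_state_dict(torch.load(path)["model_state_dict"])
buggy_optimizer = optim.SGD(buggy_model.parameters(), lr=0.1, momentum=0.9)
print(len(buggy_optimizer.state))  # 0

# Fix: restore both through one function
fixed_model = nn.Linear(2, 1)
fixed_optimizer = optim.SGD(fixed_model.parameters(), lr=0.1, momentum=0.9)
epoch = load_checkpoint(path, fixed_model, fixed_optimizer)
print(epoch, len(fixed_optimizer.state))  # 1 2
```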
Solution
Step 1: Identify required checkpoint components
To resume training exactly, save the epoch, model state, optimizer state, and best loss.
Step 2: Evaluate options
{'epoch': epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'best_loss': best_loss} includes all required keys with correct state_dict() usage.
Final Answer:
{'epoch': epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'best_loss': best_loss} -> Option D
Quick Check:
Save epoch, model, optimizer, loss in checkpoint [OK]
- Saving full model or optimizer objects
- Omitting optimizer state
- Not saving epoch or loss for training resume
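A minimal sketch of a "save on improvement" loop that writes a checkpoint with the keys from the correct option. It assumes a loss recomputed after each update as a stand-in for a real validation pass, reuses dummy tensors like those in the training example, and uses a hypothetical file path. Because it only writes on improvement, resuming here starts from the best checkpoint; many setups also keep a separate checkpoint of the latest epoch.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

path = os.path.join(tempfile.mkdtemp(), "best_checkpoint.pth")  # hypothetical path

inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
targets = torch.tensor([[1.0], [2.0]])
criterion = nn.MSELoss()
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)  # small lr keeps this toy problem stable

best_loss = float("inf")
for epoch in range(5):
    model.train()
    optimizer.zero_grad()
    criterion(model(inputs), targets).backward()
    optimizer.step()

    # Stand-in for a validation pass: re-evaluate after the update
    model.eval()
    with torch.no_grad():
        eval_loss = criterion(model(inputs), targets).item()

    # Overwrite the checkpoint only when the loss improves
    if eval_loss < best_loss:
        best_loss = eval_loss
        torch.save({
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "best_loss": best_loss,
        }, path)

# Resuming: restore best_loss too, so later epochs compare against it
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"] + 1
best_loss = checkpoint["best_loss"]
print(f"Resume at epoch {start_epoch}, best loss so far {best_loss:.4f}")
```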
