How to Resume Training from Checkpoint in PyTorch
To resume training in PyTorch, load the saved checkpoint with `torch.load()`, then restore the model weights with `model.load_state_dict()` and the optimizer state with `optimizer.load_state_dict()`. Also restore the training epoch and any other saved state so training continues seamlessly.
Syntax
Here is the typical syntax to resume training from a checkpoint in PyTorch:
- `checkpoint = torch.load(PATH)`: load the checkpoint file.
- `model.load_state_dict(checkpoint['model_state_dict'])`: restore model weights.
- `optimizer.load_state_dict(checkpoint['optimizer_state_dict'])`: restore optimizer state.
- `epoch = checkpoint['epoch']`: retrieve the last completed epoch.
- Continue training from `epoch + 1`.
```python
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
# Continue training from epoch + 1
```
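Checkpoints often hold more than the model and optimizer. As a minimal sketch of the "any other states" mentioned above, here is how a learning-rate scheduler's state can be saved and restored alongside them (the `StepLR` scheduler and its `step_size` value are illustrative assumptions, not part of the example below):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)  # assumed scheduler

# Save scheduler state alongside the usual entries
torch.save({
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
}, 'checkpoint.pth')

# Restore it together with the model and optimizer
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
```

The same pattern applies to anything that exposes a `state_dict()`/`load_state_dict()` pair.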
Example
This example shows how to save a checkpoint during training and then resume training from that checkpoint.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Simulate training and save a checkpoint each epoch
for epoch in range(3):
    # Dummy training step
    inputs = torch.randn(5, 10)
    outputs = model(inputs)
    loss = outputs.sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Save checkpoint (overwrites the previous one)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, 'checkpoint.pth')

print(f'Checkpoint saved at epoch {epoch}')

# Later: load checkpoint and resume
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
print(f'Resuming training from epoch {start_epoch}')
```
Output
Checkpoint saved at epoch 2
Resuming training from epoch 3
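The resumed run then starts its loop at `start_epoch` instead of 0. Here is a minimal sketch of that continuation, assuming a total budget of `num_epochs` epochs (an illustrative value) and reusing the dummy training step from the example above:

```python
num_epochs = 10  # assumed total number of epochs for the run

model.train()  # ensure training-mode behavior for dropout/batchnorm
for epoch in range(start_epoch, num_epochs):
    inputs = torch.randn(5, 10)
    outputs = model(inputs)
    loss = outputs.sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```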
Common Pitfalls
- Not loading the optimizer state causes the learning rate and momentum buffers to reset, which disrupts training.
- Forgetting to set the model to `train()` mode after loading a checkpoint can cause issues with layers like dropout or batch norm.
- Loading a checkpoint on a different device without specifying `map_location` can cause errors (see the `map_location` sketch below).
- Not resuming the epoch count leads to overwriting previous training progress.
```python
# Wrong way (missing optimizer load and epoch resume)
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # Missing
# epoch = checkpoint['epoch']  # Missing

# Right way
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
model.train()  # Important: set training mode
```
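For the device pitfall, `map_location` tells `torch.load()` where to place the checkpoint's tensors. A minimal sketch, reusing `model` from the example above and assuming the checkpoint may have been saved on a GPU machine:

```python
import torch

# Load a GPU-saved checkpoint onto the CPU
checkpoint = torch.load('checkpoint.pth', map_location=torch.device('cpu'))

# Or map it onto whichever device is available right now
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load('checkpoint.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)  # the model itself must also be moved to the target device
```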
Quick Reference
Remember these key steps when resuming training from a checkpoint:
- Load the checkpoint with `torch.load()`.
- Restore model and optimizer states.
- Retrieve and use the saved epoch number.
- Set the model to `train()` mode before continuing.
- Handle device mapping if loading on different hardware.
Key Takeaways
- Always load both model and optimizer states to resume training correctly.
- Restore the saved epoch to continue training without overwriting progress.
- Set the model to train mode after loading the checkpoint to enable proper behavior.
- Use `map_location` in `torch.load()` if loading a checkpoint on a different device.
- Save checkpoints regularly during training to avoid losing progress (see the sketch below).
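On the last point, one common pattern is to write a numbered checkpoint every few epochs, so an interrupted run loses at most a few epochs of work. A minimal sketch, assuming an illustrative `save_every` interval of 5 epochs and the dummy training step from earlier:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

num_epochs = 20
save_every = 5  # assumed interval; tune to your run length and failure risk

for epoch in range(num_epochs):
    # Dummy training step, as in the example above
    inputs = torch.randn(5, 10)
    loss = model(inputs).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Periodically write a numbered checkpoint so older ones survive
    if (epoch + 1) % save_every == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, f'checkpoint_epoch_{epoch}.pth')
```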