PyTorch · How-To · Beginner · 3 min read

How to Resume Training from Checkpoint in PyTorch

To resume training in PyTorch, load the saved checkpoint using torch.load(), then restore the model state with model.load_state_dict() and optimizer state with optimizer.load_state_dict(). Also, restore the training epoch and any other states to continue training seamlessly.
📐

Syntax

Here is the typical syntax to resume training from a checkpoint in PyTorch:

  • checkpoint = torch.load(PATH): Load the checkpoint file.
  • model.load_state_dict(checkpoint['model_state_dict']): Restore model weights.
  • optimizer.load_state_dict(checkpoint['optimizer_state_dict']): Restore optimizer state.
  • epoch = checkpoint['epoch']: Retrieve last completed epoch.
  • Continue training from epoch + 1.
```python
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
# Continue training from epoch + 1
```
💻

Example

This example shows how to save a checkpoint during training and then resume training from that checkpoint.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)
    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Simulate training and save checkpoint
for epoch in range(3):
    # Dummy training step
    inputs = torch.randn(5, 10)
    outputs = model(inputs)
    loss = outputs.sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Save checkpoint (overwrites the same file each epoch)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, 'checkpoint.pth')

print(f'Checkpoint saved at epoch {epoch}')

# Later: load checkpoint and resume
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1

print(f'Resuming training from epoch {start_epoch}')
```
Output
Checkpoint saved at epoch 2
Resuming training from epoch 3
⚠️

Common Pitfalls

  • Not loading the optimizer state resets the learning rate and momentum buffers, which can destabilize training.
  • Forgetting to set the model back to train() mode after loading a checkpoint causes incorrect behavior in layers like dropout and batch norm.
  • Loading a checkpoint on a different device without specifying map_location can raise device-mismatch errors.
  • Not restoring the epoch count leads to repeating epochs that were already trained.
```python
## Wrong way (missing optimizer load and epoch resume)
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # Missing
# epoch = checkpoint['epoch']  # Missing

## Right way
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
model.train()  # Important to set training mode
```
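The device-mismatch pitfall above can be avoided by passing map_location to torch.load(), which remaps saved storages to the device you name at load time. A minimal sketch, reusing a small stand-in model rather than the full example above:

```python
import torch
import torch.nn as nn

# Stand-in model and checkpoint for illustration
model = nn.Linear(10, 1)
torch.save({'model_state_dict': model.state_dict()}, 'checkpoint.pth')

# map_location remaps all tensors to the given device at load time,
# so a checkpoint saved on 'cuda:0' loads cleanly on a CPU-only machine.
checkpoint = torch.load('checkpoint.pth', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
```

map_location also accepts a string such as 'cpu' or 'cuda:0', so you can pass whatever device the current run is using.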
📊

Quick Reference

Remember these key steps when resuming training from a checkpoint:

  • Load checkpoint with torch.load().
  • Restore model and optimizer states.
  • Retrieve and use the saved epoch number.
  • Set model to train() mode before continuing.
  • Handle device mapping if loading on different hardware.
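The steps above can be combined into one resume routine. Here is a minimal sketch; the checkpoint save at the top stands in for the file produced in the earlier example, and num_epochs and the dummy batch are placeholders:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Stand-in for the checkpoint saved during the original run
torch.save({'epoch': 2,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, 'checkpoint.pth')

# 1. Load with device mapping, 2. restore states, 3. compute start epoch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
checkpoint = torch.load('checkpoint.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1

model.train()  # 4. restore training-mode behavior (dropout, batchnorm)
num_epochs = 5  # hypothetical total epoch count
for epoch in range(start_epoch, num_epochs):
    inputs = torch.randn(5, 10)  # dummy batch
    loss = model(inputs).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'Resumed training at epoch {start_epoch}')
```

Iterating over range(start_epoch, num_epochs) is what prevents the run from redoing epochs that were already trained.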

Key Takeaways

  • Always load both model and optimizer states to resume training correctly.
  • Restore the saved epoch to continue training without overwriting progress.
  • Set the model to train() mode after loading a checkpoint to restore training behavior.
  • Use map_location in torch.load() when loading a checkpoint on a different device.
  • Save checkpoints regularly during training to avoid losing progress.
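The last takeaway, saving regularly, can be sketched as a periodic save inside the training loop. The save_every interval and the dummy training step are placeholders for illustration:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

save_every = 2  # hypothetical interval in epochs
for epoch in range(6):
    loss = model(torch.randn(5, 10)).sum()  # dummy training step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Overwrite the same file so the latest checkpoint is always on disk
    if (epoch + 1) % save_every == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, 'checkpoint.pth')
```

If a run is interrupted, at most save_every epochs of work are lost, and the resume recipe above picks up from the last saved epoch.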