How to Save Checkpoint in PyTorch: Syntax and Example
To save a checkpoint in PyTorch, use torch.save() to store the model's state dictionary or the entire model. Typically, you save model.state_dict() to keep only the learned parameters, which can be loaded later for inference or to continue training.

Syntax
The basic syntax to save a checkpoint in PyTorch is:
- torch.save(obj, filepath): saves the object obj to the file path filepath. obj is usually model.state_dict() to save model weights.
- You can also save a dictionary containing the model state, optimizer state, and other info.
```python
torch.save(model.state_dict(), 'checkpoint.pth')
```

Example
This example shows how to save a simple model's weights during training as a checkpoint file named checkpoint.pth.
```python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()

# Simulate a training step (dummy forward pass)
input_tensor = torch.randn(1, 10)
output = model(input_tensor)

# Save checkpoint
torch.save(model.state_dict(), 'checkpoint.pth')
print('Checkpoint saved as checkpoint.pth')
```
Output
Checkpoint saved as checkpoint.pth
Common Pitfalls
Common mistakes when saving checkpoints in PyTorch:
- Saving the entire model object instead of state_dict() can cause issues when loading on different devices or PyTorch versions.
- Not saving the optimizer state if you want to resume training exactly.
- Using relative paths without ensuring the directory exists can cause errors.
- Forgetting to call model.eval() before inference after loading a checkpoint.
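The pitfalls above can be avoided together: create the target directory first, bundle the model and optimizer states (plus bookkeeping such as the epoch) into one dictionary, and call model.eval() after restoring. The sketch below illustrates this with a hypothetical SimpleModel and checkpoints/ directory; the dictionary keys ('epoch', 'model', 'optimizer') are a common convention, not a fixed API.

```python
import os
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical minimal model for illustration
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Ensure the target directory exists before saving
os.makedirs('checkpoints', exist_ok=True)

# Save model and optimizer state together so training can resume exactly
checkpoint = {
    'epoch': 5,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
torch.save(checkpoint, 'checkpoints/checkpoint.pth')

# Restore: rebuild the objects first, then load both state dicts
restored = torch.load('checkpoints/checkpoint.pth')
model.load_state_dict(restored['model'])
optimizer.load_state_dict(restored['optimizer'])
model.eval()  # switch to evaluation mode before inference
```

Because load_state_dict() requires an already-constructed model and optimizer of matching shape, the restore step mirrors the setup code exactly.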
```python
import torch

# Wrong way: saving the entire model object (not recommended)
# torch.save(model, 'model_full.pth')

# Right way: save the state_dict
# torch.save(model.state_dict(), 'model_weights.pth')
```
Quick Reference
| Action | Code Example |
|---|---|
| Save model weights | torch.save(model.state_dict(), 'checkpoint.pth') |
| Save model and optimizer states | torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'checkpoint.pth') |
| Load model weights | model.load_state_dict(torch.load('checkpoint.pth')) |
| Load checkpoint with optimizer | checkpoint = torch.load('checkpoint.pth'); model.load_state_dict(checkpoint['model']); optimizer.load_state_dict(checkpoint['optimizer']) |
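One loading detail not shown in the table: a checkpoint saved on a GPU will fail to load on a CPU-only machine unless you pass map_location to torch.load(). A minimal sketch, using a stand-in nn.Linear model:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in model for illustration
torch.save(model.state_dict(), 'checkpoint.pth')

# map_location remaps every tensor onto the chosen device at load time,
# so a checkpoint saved on GPU can be restored on a CPU-only machine
state = torch.load('checkpoint.pth', map_location=torch.device('cpu'))
model.load_state_dict(state)
```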
Key Takeaways
- Use torch.save(model.state_dict(), filepath) to save model weights efficiently.
- Saving the optimizer state along with the model state helps resume training seamlessly.
- Avoid saving the entire model object to prevent compatibility issues.
- Always verify that the checkpoint directory exists before saving.
- Load checkpoints with model.load_state_dict(torch.load(filepath)) for correct restoration.