PyTorch · How-To · Beginner · 3 min read

How to Save Checkpoint in PyTorch: Syntax and Example

To save a checkpoint in PyTorch, use torch.save() to store either the model's state dictionary or the entire model. Typically you save model.state_dict() to keep only the learned parameters, which can be loaded later for inference or to continue training.

Syntax

The basic syntax to save a checkpoint in PyTorch is:

  • torch.save(obj, filepath): Saves the object obj to the file path filepath.
  • obj is usually model.state_dict() to save model weights.
  • You can also save a dictionary containing model state, optimizer state, and other info.
python
torch.save(model.state_dict(), 'checkpoint.pth')
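A single dictionary can bundle model weights, optimizer state, and metadata in one file. A minimal sketch (the keys 'epoch', 'model_state_dict', and 'optimizer_state_dict' are a common convention, not names PyTorch requires):

```python
import torch
import torch.nn as nn

# Small illustrative model and optimizer
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Bundle weights, optimizer state, and metadata in one dict
checkpoint = {
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pth')
```

Anything picklable can go in the dictionary, so loss history or a learning-rate scheduler state can be stored the same way.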

Example

This example shows how to save a simple model's weights during training as a checkpoint file named checkpoint.pth.

python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()

# Simulate training step (dummy)
input_tensor = torch.randn(1, 10)
output = model(input_tensor)

# Save checkpoint
torch.save(model.state_dict(), 'checkpoint.pth')

print('Checkpoint saved as checkpoint.pth')
Output
Checkpoint saved as checkpoint.pth
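To confirm the checkpoint round-trips, the saved weights can be restored into a fresh instance of the same architecture. A minimal sketch reusing the SimpleModel class from the example above:

```python
import torch
import torch.nn as nn

# Same architecture the checkpoint was saved from
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

# Save from one instance, then restore into a fresh one
model = SimpleModel()
torch.save(model.state_dict(), 'checkpoint.pth')

restored = SimpleModel()
restored.load_state_dict(torch.load('checkpoint.pth'))
restored.eval()  # switch to inference mode before predicting
```

After load_state_dict, both instances hold identical parameters, so predictions from restored match the original model.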

Common Pitfalls

Common mistakes when saving checkpoints in PyTorch:

  • Saving the entire model object instead of state_dict() can cause issues when loading on different devices or PyTorch versions.
  • Not saving optimizer state if you want to resume training exactly.
  • Using relative paths without ensuring the directory exists can cause errors.
  • Forgetting to call model.eval() before running inference after loading a checkpoint.
python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)

# Wrong way: pickling the whole model ties the file to the exact
# class definition, module path, and PyTorch version
# torch.save(model, 'model_full.pth')  # Not recommended

# Right way: save only the state_dict (the learned parameters)
torch.save(model.state_dict(), 'model_weights.pth')
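The directory and optimizer pitfalls above can be avoided together. A sketch of saving into a created directory and restoring both states to resume training (the 'checkpoints' directory name and dict keys are illustrative choices):

```python
import os
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Ensure the target directory exists before saving
os.makedirs('checkpoints', exist_ok=True)
path = os.path.join('checkpoints', 'epoch_1.pth')

torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}, path)

# Resuming: restore both states so training continues exactly
ckpt = torch.load(path)
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
```

Without the optimizer state, adaptive optimizers like Adam would restart with empty moment estimates and training would not resume exactly where it left off.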

Quick Reference

  • Save model weights: torch.save(model.state_dict(), 'checkpoint.pth')
  • Save model and optimizer states: torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'checkpoint.pth')
  • Load model weights: model.load_state_dict(torch.load('checkpoint.pth'))
  • Load checkpoint with optimizer: checkpoint = torch.load('checkpoint.pth'); model.load_state_dict(checkpoint['model']); optimizer.load_state_dict(checkpoint['optimizer'])
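One more loading detail worth knowing: a checkpoint saved on a GPU cannot be loaded directly on a CPU-only machine unless you pass map_location to torch.load. A small sketch (saved on CPU here only to keep the example runnable):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
torch.save(model.state_dict(), 'checkpoint.pth')

# map_location remaps tensor storages at load time,
# e.g. a CUDA-saved checkpoint onto the CPU
state_dict = torch.load('checkpoint.pth', map_location='cpu')
model.load_state_dict(state_dict)
```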

Key Takeaways

  • Use torch.save(model.state_dict(), filepath) to save model weights efficiently.
  • Saving optimizer state along with model state lets you resume training seamlessly.
  • Avoid saving the entire model object to prevent compatibility issues.
  • Verify the checkpoint directory exists before saving.
  • Load checkpoints with model.load_state_dict(torch.load(filepath)) for correct restoration.