How to Save Model State Dict in PyTorch: Simple Guide
To save a PyTorch model's parameters, use torch.save(model.state_dict(), filepath). This saves only the model's learned weights and biases, which you can later load with model.load_state_dict().

Syntax
The basic syntax to save a model's state dictionary in PyTorch is:
- torch.save(model.state_dict(), filepath): Saves the model's parameters to the specified file path.
- model.state_dict(): Returns a dictionary containing all the model's parameters.
- filepath: A string path where the state dict will be saved, usually ending with .pt or .pth.
```python
torch.save(model.state_dict(), 'model_weights.pth')
```

Example
This example shows how to define a simple model, save its state dict, and then load it back into a new model instance.
```python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

# Create model instance and print initial weights
model = SimpleModel()
print('Initial weights:', model.linear.weight)

# Save the model's state dict
torch.save(model.state_dict(), 'model_weights.pth')

# Create a new model instance
new_model = SimpleModel()
print('New model weights before loading:', new_model.linear.weight)

# Load the saved state dict into the new model
new_model.load_state_dict(torch.load('model_weights.pth'))
print('New model weights after loading:', new_model.linear.weight)
```
Output
Initial weights: Parameter containing:
tensor([[...]], requires_grad=True)
New model weights before loading: Parameter containing:
tensor([[...]], requires_grad=True)
New model weights after loading: Parameter containing:
tensor([[...]], requires_grad=True)
Common Pitfalls
Common mistakes when saving and loading model state dicts include:
- Saving the entire model object instead of just the state dict, which can cause issues when loading on different machines or PyTorch versions.
- Not matching the model architecture when loading the state dict, leading to errors.
- Forgetting to call model.eval() after loading if you want to use the model for inference.
```python
import torch
import torch.nn as nn

# Wrong way: saving the entire model object
# torch.save(model, 'full_model.pth')  # Not recommended

# Right way: save only the state dict
# torch.save(model.state_dict(), 'model_weights.pth')

# When loading, create the model first, then load the weights
# model = SimpleModel()
# model.load_state_dict(torch.load('model_weights.pth'))
```
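To see what an architecture mismatch looks like in practice, here is a minimal sketch (the SmallModel/BiggerModel classes and file name are illustrative, not from the original example). By default load_state_dict is strict and raises a RuntimeError on missing or unexpected keys; passing strict=False loads only the matching parameters and reports the rest.

```python
import torch
import torch.nn as nn

# Two hypothetical models whose architectures do not fully match
class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

class BiggerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)
        self.extra = nn.Linear(1, 1)  # layer the saved weights do not have

small = SmallModel()
torch.save(small.state_dict(), 'small_weights.pth')

bigger = BiggerModel()
state = torch.load('small_weights.pth')

# Strict loading (the default) fails because the 'extra' keys are missing
try:
    bigger.load_state_dict(state)
except RuntimeError as e:
    print('Strict load failed:', e)

# strict=False loads the matching keys and reports the rest
result = bigger.load_state_dict(state, strict=False)
print('Missing keys:', result.missing_keys)
```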
Quick Reference
Remember these key points when saving and loading PyTorch model weights:
- Use torch.save(model.state_dict(), filepath) to save.
- Load with model.load_state_dict(torch.load(filepath)).
- Ensure the model architecture matches before loading.
- Call model.eval() for evaluation mode after loading.
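The points above fit together in a short round trip. One detail worth knowing when moving weights between machines: torch.load accepts a map_location argument that remaps storages at load time, so weights saved on a GPU machine can be loaded on a CPU-only machine. A minimal sketch (the nn.Linear model and file name are placeholders):

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 1)
torch.save(model.state_dict(), 'model_weights.pth')

# map_location='cpu' remaps tensor storages to CPU at load time,
# so this works even if the weights were saved from a GPU
state = torch.load('model_weights.pth', map_location='cpu')

restored = nn.Linear(2, 1)
restored.load_state_dict(state)
restored.eval()  # switch to evaluation mode for inference
```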
Key Takeaways
- Use torch.save(model.state_dict(), filepath) to save only model parameters.
- Load saved weights with model.load_state_dict(torch.load(filepath)) after creating the model.
- Ensure the model architecture matches when loading state dicts to avoid errors.
- Avoid saving the entire model object to keep compatibility and flexibility.
- Call model.eval() after loading for inference to disable dropout and batch norm updates.
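The last takeaway, that model.eval() disables dropout, can be verified directly. A small sketch using a standalone nn.Dropout layer (chosen here just for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(8)

# Training mode (the default): dropout randomly zeroes elements
# and rescales the survivors, so the output varies run to run
print('train mode:', dropout(x))

# Evaluation mode: dropout becomes a no-op, so the input passes
# through unchanged and inference is deterministic
dropout.eval()
print('eval mode:', dropout(x))
```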