How to Load Model State Dict in PyTorch: Simple Guide
To load a saved model state dictionary in PyTorch, call model.load_state_dict(torch.load(PATH)). This restores the model's learned parameters from the file saved at PATH.
Syntax
The main method for restoring a model's saved parameters is model.load_state_dict(). You pass it the state dictionary returned by torch.load(PATH).

- model: your PyTorch model instance.
- load_state_dict(): method that copies the parameters into the model.
- torch.load(PATH): loads the saved state dictionary from disk.
- PATH: file path to the saved state dict (usually a .pt or .pth file).

```python
model.load_state_dict(torch.load(PATH))
```
Example
This example shows how to define a simple model, save its state dict, and then load it back to restore the model's parameters.
```python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

# Create a model instance and print its initial weights
model = SimpleModel()
print('Initial weights:', model.linear.weight)

# Save the model state dict
PATH = 'model_state.pth'
torch.save(model.state_dict(), PATH)

# Create a new model instance (fresh random weights)
new_model = SimpleModel()
print('Weights before loading:', new_model.linear.weight)

# Load the saved state dict into the new model
new_model.load_state_dict(torch.load(PATH))
print('Weights after loading:', new_model.linear.weight)
```
Output
```
Initial weights: Parameter containing:
tensor([[...]], requires_grad=True)
Weights before loading: Parameter containing:
tensor([[...]], requires_grad=True)
Weights after loading: Parameter containing:
tensor([[...]], requires_grad=True)
```
The weights after loading match the initial weights of the first model, confirming the parameters were restored.
Common Pitfalls
Common mistakes when loading a state dict include:
- Not matching the model architecture exactly to the saved state dict.
- Forgetting to call model.eval() after loading if you want to run inference.
- Loading a state dict saved on GPU into a model on a CPU-only machine without specifying map_location.
Example of loading a GPU saved model on CPU:
```python
# Wrong way (may raise an error if the file was saved on GPU
# but you are loading on a CPU-only machine):
# model.load_state_dict(torch.load(PATH))

# Correct way for CPU loading:
model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))
```
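The architecture-mismatch pitfall can also be handled explicitly: load_state_dict() accepts a strict parameter, and with strict=False it loads the keys that match and returns the rest as missing_keys/unexpected_keys instead of raising an error. A minimal sketch, using two hypothetical architectures (SmallNet and ExtendedNet) for illustration:

```python
import torch
import torch.nn as nn

# Original architecture whose weights were saved.
class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

# Hypothetical extended architecture: same linear layer plus an extra head.
class ExtendedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)
        self.extra = nn.Linear(1, 1)  # not present in the saved state dict

saved = SmallNet().state_dict()
model = ExtendedNet()

# strict=False loads the matching keys and reports the mismatches
# instead of raising a RuntimeError.
result = model.load_state_dict(saved, strict=False)
print('Missing keys:', result.missing_keys)        # keys in the model but not in the file
print('Unexpected keys:', result.unexpected_keys)  # keys in the file but not in the model
```

Use this deliberately, e.g. when fine-tuning a model with a new head; for an exact restore, keep the default strict=True so mismatches fail loudly.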
Quick Reference
- Use torch.save(model.state_dict(), PATH) to save model parameters.
- Use model.load_state_dict(torch.load(PATH)) to load parameters.
- Call model.eval() after loading for evaluation mode.
- Use map_location if loading on a different device than the one used for saving.
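The quick-reference steps fit together as one save/restore cycle. A minimal sketch (the filename 'checkpoint.pth' is just an example):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
PATH = 'checkpoint.pth'

# Save only the parameters, not the whole model object.
torch.save(model.state_dict(), PATH)

# Later: rebuild the same architecture, then restore the weights.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))
restored.eval()  # switch dropout/batchnorm layers to inference behavior

# The restored weights match the originals exactly.
print(torch.equal(model.weight, restored.weight))  # True
```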
Key Takeaways
- Use model.load_state_dict(torch.load(PATH)) to load saved model parameters.
- Ensure the model architecture matches the saved state dict exactly.
- Use map_location=torch.device('cpu') when loading GPU-saved models on CPU.
- Call model.eval() after loading for inference to disable dropout/batchnorm training behavior.
- Save and load only the state dict, not the entire model object, for flexibility.
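The last takeaway is worth seeing concretely: torch.save(model, ...) pickles the model class itself, so loading later requires the exact class definition at the same import path, while a state dict is just a mapping of parameter names to tensors. A small sketch (filenames are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)

# Saving the entire model pickles the class, tying the file
# to this class definition and its import path.
torch.save(model, 'whole_model.pth')

# Saving only the state dict stores plain tensors keyed by
# parameter name, loadable into any matching architecture.
torch.save(model.state_dict(), 'state_only.pth')

sd = torch.load('state_only.pth')
print(sorted(sd.keys()))  # ['bias', 'weight']
```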