How to Save Model in PyTorch: Syntax and Examples
In PyTorch, you save a model with torch.save(), passing the model's state dictionary obtained from model.state_dict(). To load it back, use model.load_state_dict(torch.load(PATH)).
Syntax
To save a PyTorch model, pass its state dictionary to torch.save(). To load, call torch.load() and then model.load_state_dict().
- torch.save(obj, PATH): Saves the object obj to the file path PATH.
- model.state_dict(): Returns the model's parameters as a dictionary.
- torch.load(PATH): Loads the saved object from PATH.
- model.load_state_dict(state_dict): Loads parameters into the model.
```python
torch.save(model.state_dict(), PATH)
model.load_state_dict(torch.load(PATH))
```
Example
This example shows how to save a simple neural network's parameters to a file and then load them back into a new model instance.
```python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

# Create model and print initial weights
model = SimpleModel()
print('Initial weights:', model.linear.weight)

# Save the model state dict
PATH = 'model.pth'
torch.save(model.state_dict(), PATH)

# Create a new model instance
new_model = SimpleModel()
print('New model weights before loading:', new_model.linear.weight)

# Load the saved state dict into the new model
new_model.load_state_dict(torch.load(PATH))
print('New model weights after loading:', new_model.linear.weight)
```
Output
Initial weights: Parameter containing:
tensor([[...]], requires_grad=True)
New model weights before loading: Parameter containing:
tensor([[...]], requires_grad=True)
New model weights after loading: Parameter containing:
tensor([[...]], requires_grad=True)
Common Pitfalls
Saving the entire model object instead of the state dict: a pickled model object depends on the exact class definition and project layout, so it can fail to load on different machines or PyTorch versions.
Not calling model.eval() before inference: evaluation mode changes the behavior of layers like dropout and batch normalization, so forgetting it after loading can produce inconsistent predictions. (It does not affect what gets saved in the state dict.)
Forgetting to load the state dict into the model instance: Loading the file alone does not update the model.
```python
import torch
import torch.nn as nn

# Wrong way: saving the entire model object
# torch.save(model, PATH)  # Not recommended

# Right way: save the state dict
# torch.save(model.state_dict(), PATH)

# Wrong way: loading the file without a model instance
# state = torch.load(PATH)  # Does not update any model

# Right way: load the state dict into a model instance
# model.load_state_dict(torch.load(PATH))
```
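To make the third pitfall concrete, the runnable sketch below (reusing the SimpleModel class and 'model.pth' filename from the example above, both illustrative) verifies that a new model's weights change only once load_state_dict() is actually called:

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

PATH = 'model.pth'
model = SimpleModel()
torch.save(model.state_dict(), PATH)

new_model = SimpleModel()
state = torch.load(PATH)  # reads the file into a dict...
# ...but new_model's randomly initialized weights are untouched so far
same_before = torch.equal(new_model.linear.weight, model.linear.weight)

new_model.load_state_dict(state)  # this step copies the weights in
same_after = torch.equal(new_model.linear.weight, model.linear.weight)

print(same_before, same_after)  # weights match only after load_state_dict()
```

Calling torch.load() on its own just deserializes the dictionary; it has no way of knowing which model you intend to update.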
Quick Reference
Remember these key points when saving and loading PyTorch models:
- Always save model.state_dict(), not the whole model.
- Use torch.save() and torch.load() for file operations.
- Load weights with model.load_state_dict() into a model instance.
- Call model.eval() before inference to set evaluation mode.
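Putting the checklist together, here is a minimal end-to-end sketch of the save-then-infer workflow (the SimpleModel class, 'model.pth' filename, and sample input are illustrative):

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

PATH = 'model.pth'

# Training side: save the state dict, not the model object
model = SimpleModel()
torch.save(model.state_dict(), PATH)

# Inference side: create an instance first, then load weights into it
loaded = SimpleModel()
loaded.load_state_dict(torch.load(PATH))
loaded.eval()  # set evaluation mode before running inference

with torch.no_grad():  # no gradients needed for inference
    prediction = loaded(torch.tensor([[1.0, 2.0]]))
print(prediction.shape)  # torch.Size([1, 1])
```

The torch.no_grad() context is optional but common for inference, since it skips gradient tracking.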
Key Takeaways
Save PyTorch models by storing their state dictionary with torch.save(model.state_dict(), PATH).
Load saved weights into a model instance using model.load_state_dict(torch.load(PATH)).
Avoid saving the entire model object to ensure compatibility and flexibility.
Call model.eval() before inference to set the model to evaluation mode.
Always create a model instance before loading saved weights.