
How to Save Model in PyTorch: Syntax and Examples

In PyTorch, you save a model using torch.save() to store the model's state dictionary with model.state_dict(). To load it back, use model.load_state_dict(torch.load(PATH)).
📐 Syntax

To save a PyTorch model, use torch.save() with the model's state dictionary. To load, use torch.load() and then model.load_state_dict().

  • torch.save(obj, PATH): Saves the object obj to the file path PATH.
  • model.state_dict(): Returns the model's parameters as a dictionary.
  • torch.load(PATH): Loads the saved object from PATH.
  • model.load_state_dict(state_dict): Loads parameters into the model.
```python
# Save the model's parameters
torch.save(model.state_dict(), PATH)

# Recreate the model, then load the parameters into it
model.load_state_dict(torch.load(PATH))
```
💻 Example

This example shows how to save a simple neural network's parameters to a file and then load them back into a new model instance.

```python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

# Create model and print initial weights
model = SimpleModel()
print('Initial weights:', model.linear.weight)

# Save the model state dict
PATH = 'model.pth'
torch.save(model.state_dict(), PATH)

# Create a new model instance
new_model = SimpleModel()
print('New model weights before loading:', new_model.linear.weight)

# Load the saved state dict into the new model
new_model.load_state_dict(torch.load(PATH))
print('New model weights after loading:', new_model.linear.weight)
```
Output
```
Initial weights: Parameter containing:
tensor([[...]], requires_grad=True)
New model weights before loading: Parameter containing:
tensor([[...]], requires_grad=True)
New model weights after loading: Parameter containing:
tensor([[...]], requires_grad=True)
```
⚠️ Common Pitfalls

Saving the entire model object instead of the state dict: torch.save(model, PATH) pickles the full model, tying the file to the exact class definition and directory layout, so it can break when loaded on different machines or PyTorch versions.

Not calling model.eval() before running inference on a loaded model: layers like dropout and batch norm behave differently in training mode, so predictions can be nondeterministic or wrong.

Forgetting to load the state dict into the model instance: Loading the file alone does not update the model.

```python
import torch
import torch.nn as nn

# Wrong: saving the entire model (pickles the class; fragile across versions)
# torch.save(model, PATH)  # Not recommended

# Right: save the state dict
# torch.save(model.state_dict(), PATH)

# Wrong: loading the file alone does not update any model
# state = torch.load(PATH)

# Right: load the state dict into an existing model instance
# model.load_state_dict(torch.load(PATH))
```
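To see why evaluation mode matters, here is a minimal sketch (assuming PyTorch is installed) showing that a Dropout layer randomizes activations in training mode but becomes a no-op after model.eval(), making outputs deterministic:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

# In training mode, dropout randomly zeroes activations,
# so repeated forward passes can disagree.
model.train()
train_out = model(x)

# In eval mode, dropout is disabled and forward passes are deterministic.
model.eval()
out1 = model(x)
out2 = model(x)
print(torch.equal(out1, out2))  # prints True
```

The same reasoning applies to batch norm, which switches from batch statistics to running statistics in eval mode.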
📊 Quick Reference

Remember these key points when saving and loading PyTorch models:

  • Always save model.state_dict(), not the whole model.
  • Use torch.save() and torch.load() for file operations.
  • Load weights with model.load_state_dict() into a model instance.
  • Call model.eval() before inference to set evaluation mode.
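Putting these points together, a minimal end-to-end sketch of the save-then-serve workflow looks like this (the file name model.pth and the SimpleModel class are illustrative):

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

PATH = 'model.pth'  # example file name

# Training side: save only the state dict.
model = SimpleModel()
torch.save(model.state_dict(), PATH)

# Inference side: recreate the architecture, then load the weights.
new_model = SimpleModel()
new_model.load_state_dict(torch.load(PATH))
new_model.eval()  # disable dropout / use running batch-norm stats

with torch.no_grad():  # no gradients needed at inference time
    pred = new_model(torch.tensor([[1.0, 2.0]]))
print(pred.shape)  # prints torch.Size([1, 1])
```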

Key Takeaways

  • Save PyTorch models by storing their state dictionary with torch.save(model.state_dict(), PATH).
  • Load saved weights into a model instance using model.load_state_dict(torch.load(PATH)).
  • Avoid saving the entire model object to ensure compatibility and flexibility.
  • Call model.eval() before inference to set the model to evaluation mode.
  • Always create a model instance before loading saved weights.