
Saving model state_dict in PyTorch

Introduction
Saving a model's state_dict preserves its learned parameters, so you can reuse the model or keep training it later without starting over. This is useful when:
You want to pause training and continue later without losing progress.
You finished training and want to save the model to make predictions later.
You want to share your trained model with others.
You want to keep different versions of your model during experiments.
Syntax
PyTorch
torch.save(model.state_dict(), 'filename.pth')
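The state_dict is a dictionary that maps each layer's parameter names to tensors. It holds the model's learnable parameters (weights and biases) as well as persistent buffers, such as the running statistics of batch normalization layers.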
The filename usually ends with '.pth' or '.pt' to indicate a PyTorch model file.
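To make the contents of a state_dict concrete, the sketch below prints the name and shape of every entry. The three-layer nn.Sequential model is a hypothetical example chosen only for illustration; it is not part of the original tutorial.

```python
import torch
import torch.nn as nn

# Hypothetical model, used only to show what a state_dict contains
model = nn.Sequential(
    nn.Linear(2, 4),
    nn.BatchNorm1d(4),
    nn.Linear(4, 1),
)

state = model.state_dict()

# Each key is "<layer name>.<tensor name>"; each value is a tensor
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
```

The entries for the batch normalization layer include running_mean, running_var, and num_batches_tracked. These are buffers rather than trained weights, but they are saved along with the parameters.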
Examples
Save the model's parameters to a file named 'model_weights.pth'.
PyTorch
torch.save(model.state_dict(), 'model_weights.pth')
Store the file path in a variable so the same path can be reused when saving and loading.
PyTorch
PATH = 'checkpoint.pth'
torch.save(model.state_dict(), PATH)
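If the goal is to pause training and continue later, the model's state_dict alone may not be enough, because the optimizer keeps its own state. One common pattern, sketched below, is to save a dictionary that bundles both state_dicts with the epoch number. The dictionary keys, the nn.Linear model, the SGD optimizer, and the temporary directory are all assumptions made for this example.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical model and optimizer for illustration
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# One dummy training step so there is some progress to save
loss = model(torch.randn(8, 2)).pow(2).mean()
loss.backward()
optimizer.step()

# Bundle everything needed to resume into one dictionary
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
torch.save(
    {
        "epoch": 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    path,
)

# Later: rebuild the model and optimizer, then restore their state
resumed_model = nn.Linear(2, 1)
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.1)

checkpoint = torch.load(path)
resumed_model.load_state_dict(checkpoint["model_state_dict"])
resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # continue from the next epoch
```

Including the epoch number in the filename (for example 'checkpoint_epoch_1.pth') is one simple way to keep several versions side by side.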
Sample Model
This code creates a simple model, prints its initial weights, saves the weights to a file, then loads them into a new model and prints the loaded weights to confirm they match.
PyTorch
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        # One linear layer: 2 input features, 1 output
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

# Create model instance
model = SimpleModel()

# Print initial weights
print('Initial weights:', model.linear.weight)

# Save the model's state_dict
torch.save(model.state_dict(), 'simple_model.pth')

# Load the state_dict into a new model to check
new_model = SimpleModel()
new_model.load_state_dict(torch.load('simple_model.pth'))

# Print loaded weights
print('Loaded weights:', new_model.linear.weight)
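Printing the two tensors only allows a comparison by eye. A programmatic check is sketched below: it compares every entry of the saved and loaded state_dicts with torch.equal. For brevity it uses a bare nn.Linear(2, 1) as a stand-in for SimpleModel and writes to a temporary directory; both are assumptions of this example.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for SimpleModel: a single linear layer
model = nn.Linear(2, 1)

path = os.path.join(tempfile.mkdtemp(), "simple_model.pth")
torch.save(model.state_dict(), path)

# A fresh instance starts with different random weights
new_model = nn.Linear(2, 1)
new_model.load_state_dict(torch.load(path))

saved = model.state_dict()
loaded = new_model.state_dict()

# Same keys, and every tensor identical element by element
all_match = saved.keys() == loaded.keys() and all(
    torch.equal(saved[name], loaded[name]) for name in saved
)
print("All parameters match:", all_match)
```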
Important Notes
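Prefer saving the state_dict over the whole model object. Pickling the whole model ties the file to the exact class definitions and module paths that existed when it was saved, so it can break after the code is refactored.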
When loading, create the model architecture first, then load the saved state_dict.
Use consistent file paths and names so each saved file can be traced back to the model and experiment that produced it.
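The requirement that the architecture match the saved state_dict can be seen in a small experiment. In the sketch below, loading into a model with the same layer shapes succeeds, while loading into a model with a different input size raises a RuntimeError. The layer sizes are arbitrary choices for this example.

```python
import torch
import torch.nn as nn

# State saved from a layer with 2 input features
saved_state = nn.Linear(2, 1).state_dict()

# Same architecture: loading succeeds
matching = nn.Linear(2, 1)
matching.load_state_dict(saved_state)
matching.eval()  # switch to evaluation mode before making predictions

# Different architecture (3 input features): shapes do not match
mismatched = nn.Linear(3, 1)
try:
    mismatched.load_state_dict(saved_state)
    load_failed = False
except RuntimeError as err:
    load_failed = True
    print("Load failed:", str(err).splitlines()[0])
```

Calling eval() matters for layers such as dropout and batch normalization, which behave differently during training and inference.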
Summary
Saving a model's state_dict stores only its learned parameters and buffers, not the architecture itself.
Use torch.save(model.state_dict(), 'file.pth') to save.
Load saved weights into the same model architecture with load_state_dict.