PyTorch · How-To · Beginner · 3 min read

How to Load Model State Dict in PyTorch: Simple Guide

To load a saved model state dictionary in PyTorch, use model.load_state_dict(torch.load(PATH)). This restores the model's learned parameters from the saved file at PATH.
📐 Syntax

The main function to load a model's saved parameters is model.load_state_dict(). You pass it the loaded state dictionary from a file using torch.load(PATH).

  • model: Your PyTorch model instance.
  • load_state_dict(): Method to load parameters into the model.
  • torch.load(PATH): Loads the saved state dictionary from disk.
  • PATH: File path to the saved state dict (usually a .pt or .pth file).
```python
model.load_state_dict(torch.load(PATH))
```
💻 Example

This example shows how to define a simple model, save its state dict, and then load it back to restore the model's parameters.

```python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

# Create a model instance and print its initial weights
model = SimpleModel()
print('Initial weights:', model.linear.weight)

# Save the model's state dict
PATH = 'model_state.pth'
torch.save(model.state_dict(), PATH)

# Create a new model instance (randomly initialized)
new_model = SimpleModel()
print('Weights before loading:', new_model.linear.weight)

# Load the saved state dict into the new model
new_model.load_state_dict(torch.load(PATH))
print('Weights after loading:', new_model.linear.weight)
```
Output
```
Initial weights: Parameter containing:
tensor([[...]], requires_grad=True)
Weights before loading: Parameter containing:
tensor([[...]], requires_grad=True)
Weights after loading: Parameter containing:
tensor([[...]], requires_grad=True)
```
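To confirm that loading really restored the parameters, you can compare the tensors directly rather than eyeballing the printed values. A small self-contained check (the file name `model_state.pth` is just an example):

```python
import torch
import torch.nn as nn

# Same architecture as in the example above
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
torch.save(model.state_dict(), 'model_state.pth')

new_model = SimpleModel()  # fresh, randomly initialized weights
new_model.load_state_dict(torch.load('model_state.pth'))

# After loading, the parameters should be identical element for element
print(torch.equal(model.linear.weight, new_model.linear.weight))  # True
print(torch.equal(model.linear.bias, new_model.linear.bias))      # True
```

`torch.equal` returns `True` only when both tensors have the same shape and values, so it is a stricter check than comparing printed output.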
⚠️ Common Pitfalls

Common mistakes when loading a state dict include:

  • Not matching the model architecture exactly with the saved state dict.
  • Forgetting to call model.eval() after loading if you want to run inference.
  • Loading a state dict saved on GPU to a CPU model without specifying map_location.
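The architecture-mismatch pitfall is easy to demonstrate: with the default `strict=True`, `load_state_dict()` raises a `RuntimeError` when the keys don't line up, while `strict=False` turns the mismatch into a report instead. A minimal sketch (the mismatched `nn.Sequential` model here is purely illustrative):

```python
import torch
import torch.nn as nn

saved = nn.Linear(2, 1)
torch.save(saved.state_dict(), 'linear_state.pth')

# A model whose parameter names differ: Sequential prefixes keys with '0.'
mismatched = nn.Sequential(nn.Linear(2, 1))
state = torch.load('linear_state.pth')

try:
    mismatched.load_state_dict(state)  # 'weight'/'bias' vs '0.weight'/'0.bias'
except RuntimeError as e:
    print('Mismatch:', e)

# strict=False skips mismatched keys and returns them for inspection
result = mismatched.load_state_dict(state, strict=False)
print(result.missing_keys)     # keys the model expected but didn't get
print(result.unexpected_keys)  # keys in the file the model doesn't have
```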

Example of loading a GPU-saved model on CPU:

```python
# Wrong way (may raise an error if saved on GPU but loaded on a CPU-only machine)
# model.load_state_dict(torch.load(PATH))

# Correct way for CPU loading
model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))
```
📊 Quick Reference

  • Use torch.save(model.state_dict(), PATH) to save model parameters.
  • Use model.load_state_dict(torch.load(PATH)) to load parameters.
  • Call model.eval() after loading for evaluation mode.
  • Use map_location if loading on different device than saved.
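Putting the quick-reference points together, a typical load-then-infer pattern looks like this sketch (the `nn.Linear` model and file name are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 1)
torch.save(model.state_dict(), 'model_state.pth')

# Rebuild the architecture, then load the saved parameters into it
restored = nn.Linear(2, 1)
restored.load_state_dict(torch.load('model_state.pth'))

restored.eval()          # switch dropout/batchnorm to inference behavior
with torch.no_grad():    # disable gradient tracking for inference
    output = restored(torch.randn(1, 2))

print(output.shape)  # torch.Size([1, 1])
```

`model.eval()` changes layer behavior (dropout, batch norm), while `torch.no_grad()` skips autograd bookkeeping; for inference you generally want both.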

Key Takeaways

  • Use model.load_state_dict(torch.load(PATH)) to load saved model parameters.
  • Ensure the model architecture matches the saved state dict exactly.
  • Use map_location=torch.device('cpu') when loading GPU-saved weights on CPU.
  • Call model.eval() before inference to disable dropout/batchnorm training behavior.
  • Save and load the state dict rather than the entire model object for flexibility.
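The last takeaway can be illustrated by contrasting the two saving styles. Saving the whole module pickles your class, tying the file to your exact code layout; saving the state dict keeps only the tensors. A minimal sketch (file names are placeholders):

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 1)

# Recommended: save only the parameters; loading requires the class definition
torch.save(model.state_dict(), 'params.pth')

# Alternative: pickle the whole module; ties the file to your exact code layout
# (in recent PyTorch, loading it back needs torch.load(..., weights_only=False))
torch.save(model, 'full_model.pth')

# Loading the state dict means reconstructing the architecture first
restored = nn.Linear(2, 1)
restored.load_state_dict(torch.load('params.pth'))
print(torch.equal(model.weight, restored.weight))  # True
```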