
Loading model state_dict in PyTorch

Introduction

Loading a model's state_dict lets you reuse the parameters a saved model has already learned. This saves training time and lets you continue training or make predictions without starting from scratch. Typical situations:

You want to continue training a model from where you left off.
You want to use a pretrained model to make predictions on new data.
You want to share a trained model with someone else.
You want to test a model saved earlier without retraining.
You want to fine-tune a model on a new dataset.
Syntax
PyTorch
model.load_state_dict(torch.load(PATH))

model is your PyTorch model instance.

PATH is the file path to the saved state_dict.
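In practice, torch.load is often called with two optional arguments: map_location, which remaps tensors saved on one device (for example a GPU) onto another, and weights_only=True, which restricts unpickling to tensors and plain containers. The following is a minimal sketch, assuming a PyTorch version that supports weights_only; the file name and the stand-in nn.Linear model are illustrative assumptions, not part of the original example.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in model for illustration; any nn.Module is handled the same way.
model = nn.Linear(2, 1)

# Hypothetical file name inside a temporary directory.
path = os.path.join(tempfile.mkdtemp(), "example_weights.pth")
torch.save(model.state_dict(), path)

# map_location="cpu" places the loaded tensors on the CPU, whatever device they were saved from.
# weights_only=True limits unpickling to tensors and basic Python containers.
state_dict = torch.load(path, map_location="cpu", weights_only=True)

# Apply the loaded weights to a freshly created model of the same architecture.
fresh_model = nn.Linear(2, 1)
fresh_model.load_state_dict(state_dict)
print("Weights match:", torch.equal(fresh_model.weight, model.weight))
```

Loading onto the CPU first and moving the model to a GPU afterwards with model.to(device) is one common way to keep a checkpoint usable on machines without a GPU.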

Examples
Loads weights from 'model_weights.pth' into the model.
PyTorch
model.load_state_dict(torch.load('model_weights.pth'))
Loads the state_dict first, then applies it to the model.
PyTorch
state_dict = torch.load('weights.pth')
model.load_state_dict(state_dict)
Loads the weights whose keys match and ignores missing or unexpected keys instead of raising an error.
PyTorch
model.load_state_dict(torch.load('model.pth'), strict=False)
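With strict=False it helps to inspect what was skipped. load_state_dict returns an object with missing_keys and unexpected_keys attributes. The sketch below uses two hypothetical classes, Backbone and BackboneWithHead, and an artificial extra.weight entry, all invented here purely for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical model whose weights were "saved" earlier.
class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

# Hypothetical newer model: same first layer plus an extra head.
class BackboneWithHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)
        self.head = nn.Linear(1, 1)

    def forward(self, x):
        return self.head(self.linear(x))

saved = Backbone().state_dict()
saved["extra.weight"] = torch.zeros(1)  # artificial key the new model does not have

model = BackboneWithHead()
result = model.load_state_dict(saved, strict=False)

# Keys the model expects but the saved state_dict lacks.
print("Missing keys:", result.missing_keys)
# Keys present in the saved state_dict but unknown to the model.
print("Unexpected keys:", result.unexpected_keys)
```

Checking these two lists after every non-strict load guards against silently leaving layers at their random initial values.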
Sample Model

This program creates a simple linear model, saves its weights, changes them to zero, then reloads the saved weights to restore the original values.

PyTorch
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

# Create model instance and print initial weights
model = SimpleModel()
print('Initial weights:', model.linear.weight)

# Save the model's state_dict
PATH = 'simple_model.pth'
torch.save(model.state_dict(), PATH)

# Change weights manually
with torch.no_grad():
    model.linear.weight.fill_(0)
print('Weights after zeroing:', model.linear.weight)

# Load saved weights back
model.load_state_dict(torch.load(PATH))
print('Weights after loading:', model.linear.weight)
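Rather than comparing printed tensors by eye, the round trip can be checked programmatically. This is a minimal sketch of one way to do so, assuming a plain nn.Linear as a stand-in for the model and a temporary file as the save location; torch.equal reports whether two tensors have the same shape and values.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(2, 1)  # stand-in for the model above

# Clone the tensors: state_dict() entries share memory with the live parameters,
# so without clone() the reference copy would be zeroed too.
original = {name: t.clone() for name, t in model.state_dict().items()}

path = os.path.join(tempfile.mkdtemp(), "simple_model.pth")
torch.save(model.state_dict(), path)

# Overwrite the weights in place, as in the program above.
with torch.no_grad():
    model.weight.fill_(0)
zeroed = bool((model.weight == 0).all())

# Reload and compare every tensor against the reference copy.
model.load_state_dict(torch.load(path))
restored = all(torch.equal(original[name], t) for name, t in model.state_dict().items())
print("Weights were zeroed:", zeroed)
print("Weights restored:", restored)
```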
Important Notes

Always load a saved state_dict into a model with the same architecture, meaning the same layer names and tensor shapes, as the model that produced it.

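Use strict=False only when the saved keys don't exactly match the model's keys, for example when loading a subset of layers. It ignores missing and unexpected keys but still raises an error if tensor shapes differ.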

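A state_dict stores only tensors: the model's parameters and persistent buffers (such as BatchNorm running statistics). It does not store the model's code or architecture.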
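The notes above can be checked with a short experiment. The sketch below is illustrative only: it lists the entries of a BatchNorm layer's state_dict, which includes buffers alongside parameters, and then shows that loading weights into a layer with different shapes raises a RuntimeError even with strict=False. The layer sizes are arbitrary choices for the demonstration.

```python
import torch
import torch.nn as nn

# A state_dict is a mapping from names to tensors; it carries no model code.
bn = nn.BatchNorm1d(3)
state_keys = list(bn.state_dict().keys())
buffer_names = [name for name, _ in bn.named_buffers()]
print("state_dict keys:", state_keys)
print("of which buffers:", buffer_names)

# Weights saved from a layer with 2 input features...
saved = nn.Linear(2, 1).state_dict()
# ...do not fit a layer with 3 input features, although the key names match.
wider_model = nn.Linear(3, 1)

mismatch_error = None
try:
    wider_model.load_state_dict(saved, strict=False)
except RuntimeError as err:
    mismatch_error = err  # size mismatch is reported regardless of strict

print("Shape mismatch raised:", type(mismatch_error).__name__)
```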

Summary

Loading a state_dict restores saved model weights.

Use torch.load(PATH) to load the saved weights.

Apply weights with model.load_state_dict().