How to Load a Model in PyTorch: Syntax and Example
To load a model in PyTorch, use torch.load() to load the saved state dictionary, then apply model.load_state_dict() to copy the weights into the model. Finally, call model.eval() to set the model to evaluation mode for inference.

Syntax

Loading a model in PyTorch usually involves these steps:

- torch.load(PATH): Loads the saved state dictionary from the file.
- model.load_state_dict(state_dict): Loads the weights into the model architecture.
- model.eval(): Sets the model to evaluation mode, turning off training-specific layers like dropout.
```python
state_dict = torch.load(PATH)
model.load_state_dict(state_dict)
model.eval()
```
Example
This example shows how to define a simple model, save its weights, and then load them back for inference.
```python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

# Create and save the model
model = SimpleModel()
PATH = 'simple_model.pth'
torch.save(model.state_dict(), PATH)

# Load the model
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load(PATH))
loaded_model.eval()

# Test the loaded model
input_tensor = torch.tensor([[1.0, 2.0]])
output = loaded_model(input_tensor)
print(output)
```
Output

tensor([[0.1234]], grad_fn=<AddmmBackward0>)

The exact values will differ on each run because the model's weights are randomly initialized.
Common Pitfalls
Common mistakes when loading models in PyTorch include:

- Not matching the model architecture before loading weights, which causes errors.
- Forgetting to call model.eval(), which leads to incorrect inference results because layers like dropout stay active.
- Loading the entire model with torch.load() instead of just the state dictionary, which can cause issues if the model class is not available.
```python
import torch
import torch.nn as nn

# Wrong way: loading the entire model without the class definition
# loaded_model = torch.load('model.pth')  # This can fail if the class is missing

# Right way: define the architecture, then load only the state dictionary
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
model.load_state_dict(torch.load('model.pth'))
model.eval()
```
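The first pitfall above, a mismatched architecture, surfaces as a RuntimeError from load_state_dict rather than silently wrong results. A minimal sketch (the file name and layer sizes are illustrative):

```python
import torch
import torch.nn as nn

# Save weights from a layer that expects 2 input features
model_a = nn.Linear(2, 1)
torch.save(model_a.state_dict(), 'mismatch_demo.pth')

# Attempt to load them into a layer with a different shape
model_b = nn.Linear(3, 1)  # wrong input size

err = None
try:
    model_b.load_state_dict(torch.load('mismatch_demo.pth'))
except RuntimeError as e:
    err = e  # reports a size mismatch for the 'weight' tensor

print(type(err).__name__)
```

Checking for this error early is cheaper than debugging bad predictions later.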
Quick Reference
| Step | Function | Purpose |
|---|---|---|
| 1 | torch.load(PATH) | Load saved weights from file |
| 2 | model.load_state_dict(state_dict) | Load weights into model |
| 3 | model.eval() | Set model to evaluation mode for inference |
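One case the table does not cover: when weights were saved on a GPU but are being loaded on a CPU-only machine, torch.load accepts a map_location argument. A minimal sketch (the file name is illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 1)
torch.save(model.state_dict(), 'weights_demo.pth')

# map_location remaps tensor storages at load time, so weights saved
# on a 'cuda' device can still be loaded on a CPU-only machine
state_dict = torch.load('weights_demo.pth', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()  # model.training is now False
```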
Key Takeaways
- Always load the saved state dictionary with torch.load and then apply it to the model with load_state_dict.
- Call model.eval() after loading to ensure correct behavior during inference.
- Make sure the model architecture matches the saved weights before loading.
- Avoid saving and loading the entire model object; prefer saving the state_dict for flexibility.
- Loading a pickled model object without its class definition available will cause errors.
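Building on these takeaways, inference code typically pairs model.eval() with torch.no_grad(), which disables gradient tracking for the forward pass. A minimal sketch:

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 1)
model.eval()  # evaluation mode: dropout/batch-norm layers behave deterministically

# no_grad() skips autograd bookkeeping, saving memory and time at inference
with torch.no_grad():
    output = model(torch.tensor([[1.0, 2.0]]))

print(output.requires_grad)  # → False
```

eval() changes layer behavior, while no_grad() changes autograd behavior; they address different things, so inference code usually uses both.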