How to Use torch.save in PyTorch for Saving Models and Tensors
Use `torch.save(obj, filepath)` to save a PyTorch object, such as a tensor or a model's state dictionary, to a file. The object is serialized with Python's pickle (wrapped in a zip archive by default) into a binary file that you can load later with `torch.load()`.
Syntax
The basic syntax of torch.save is:

```python
torch.save(obj, filepath)
```

- `obj`: The PyTorch object to save (tensor, model state dictionary, plain dictionary, etc.).
- `filepath`: Where the object will be written: a file path (string or `os.PathLike`) or a writable file-like object.
- Optionally, you can pass `_use_new_zipfile_serialization` for advanced control over the file format (default `True`, the zip-based format).
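The `filepath` argument does not have to be a path on disk: `torch.save` also accepts a writable file-like object. The minimal sketch below assumes you want to serialize into memory (for example, before sending the bytes elsewhere); the tensor values are arbitrary.

```python
import io
import torch

# Serialize into an in-memory buffer instead of a file on disk
buffer = io.BytesIO()
x = torch.arange(6).reshape(2, 3)
torch.save(x, buffer)

# Rewind to the beginning so torch.load sees the saved bytes
buffer.seek(0)
restored = torch.load(buffer)

print(torch.equal(x, restored))  # True
```

The same pattern works with any object that implements `write` (for saving) and `read`/`seek` (for loading).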
Example
This example shows how to save a simple tensor and a model's state dictionary using torch.save. It also shows how to load them back with torch.load.
```python
import torch
import torch.nn as nn

# Create a tensor
x = torch.tensor([1, 2, 3, 4])

# Save the tensor to a file
torch.save(x, 'tensor.pt')

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, input):
        return self.linear(input)

model = SimpleModel()

# Save the model's state dictionary
torch.save(model.state_dict(), 'model_state.pth')

# Load the tensor back
loaded_x = torch.load('tensor.pt')
print('Loaded tensor:', loaded_x)

# Load the state dict back into a fresh model instance
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load('model_state.pth'))
print('Model loaded successfully')
```
Output
Loaded tensor: tensor([1, 2, 3, 4])
Model loaded successfully
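Because `obj` can be a dictionary, a common pattern is to bundle several state dictionaries into one checkpoint file. The sketch below assumes you want to resume training later; the key names (`"epoch"`, `"model_state"`, `"optimizer_state"`) and the epoch number are hypothetical choices, not anything `torch.save` requires.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Bundle everything needed to resume into one dictionary (key names are illustrative)
checkpoint = {
    "epoch": 5,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}
torch.save(checkpoint, "checkpoint.pth")

# Later: rebuild the objects, then restore their state from the file
restored = torch.load("checkpoint.pth", weights_only=True)
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1)
new_model.load_state_dict(restored["model_state"])
new_optimizer.load_state_dict(restored["optimizer_state"])
start_epoch = restored["epoch"] + 1

print(start_epoch)  # 6
```

The checkpoint contains only tensors and plain Python containers and numbers, so it can be loaded with `weights_only=True`, which restricts unpickling to those safe types.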
Common Pitfalls
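- Saving entire models: Avoid saving the whole model object with `torch.save(model, filepath)`. This pickles the module together with a reference to its class, so loading can fail if the class is renamed or moved, or behave unexpectedly if its code has changed. Save `model.state_dict()` instead.
- File paths: `torch.save` does not create missing directories. Make sure the target directory exists before saving, or the call raises an error.
- Device mismatch: By default, tensors are restored to the device they were saved from. If the saved tensor or model was on a GPU and you load it on a machine without one, pass `map_location=torch.device('cpu')` to `torch.load`.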
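```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # any nn.Module works here

# Not recommended: pickles the whole module, tying the file to the class definition
# torch.save(model, 'model_full.pth')

# Recommended: save only the parameters and buffers
torch.save(model.state_dict(), 'model_state.pth')
```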
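The file-path and device pitfalls can both be handled at the call site. This is a minimal sketch, assuming a hypothetical output directory `checkpoints/run1`; it uses a GPU tensor when CUDA is available and a CPU tensor otherwise, and always loads onto the CPU.

```python
import os
import torch

save_dir = os.path.join("checkpoints", "run1")  # hypothetical output directory
# torch.save does not create missing parent directories, so create them first
os.makedirs(save_dir, exist_ok=True)
path = os.path.join(save_dir, "tensor.pt")

# Use a GPU tensor if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.ones(3, device=device)
torch.save(x, path)

# map_location remaps saved tensors onto the CPU while loading,
# so this also works on a machine without CUDA
loaded = torch.load(path, map_location="cpu")
print(loaded.device)  # cpu
```

`exist_ok=True` makes the directory creation safe to repeat, so the same code can run at every save.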
Quick Reference
Here is a quick summary of torch.save usage:
| Function | Description |
|---|---|
| torch.save(obj, filepath) | Save a PyTorch object to a file. |
| torch.load(filepath) | Load a saved PyTorch object from a file. |
| model.state_dict() | Get the model's parameters and persistent buffers; save this instead of the whole model. |
| model.load_state_dict(torch.load(filepath)) | Load saved parameters and buffers into a model instance. |
| torch.load(filepath, map_location='cpu') | Load an object saved from a GPU onto the CPU. |
Key Takeaways
Use torch.save(obj, filepath) to save tensors, models, or dictionaries.
Prefer saving model.state_dict() instead of the whole model object.
Load saved objects with torch.load(filepath), specifying map_location if needed.
Ensure the save directory exists to avoid file errors.
Saving and loading a state_dict keeps the file independent of your class code, so the weights can be loaded into any model with matching parameter names and shapes.