
How to Use torch.save in PyTorch for Saving Models and Tensors

Use torch.save(obj, filepath) to save a PyTorch object like a model or tensor to a file. This stores the object in a binary format that you can load later with torch.load().
📐

Syntax

The basic syntax of torch.save is:

```python
torch.save(obj, filepath)
```

  • obj: The PyTorch object to save (model, tensor, dictionary, etc.).
  • filepath: The file path (string) where the object will be saved.
  • Optionally, you can pass _use_new_zipfile_serialization for advanced control over the on-disk format (default is True).
💻

Example

This example shows how to save a simple tensor and a model's state dictionary using torch.save. It also shows how to load them back with torch.load.

```python
import torch
import torch.nn as nn

# Create a tensor
x = torch.tensor([1, 2, 3, 4])

# Save the tensor to a file
torch.save(x, 'tensor.pt')

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, input):
        return self.linear(input)

model = SimpleModel()

# Save the model's state dictionary
torch.save(model.state_dict(), 'model_state.pth')

# Load the tensor back
loaded_x = torch.load('tensor.pt')
print('Loaded tensor:', loaded_x)

# Load the model state dict back
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load('model_state.pth'))
print('Model loaded successfully')
```
Output

```
Loaded tensor: tensor([1, 2, 3, 4])
Model loaded successfully
```
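Because torch.save accepts plain Python dictionaries, a common pattern is to bundle everything needed to resume training into one checkpoint file. A minimal sketch (the file name and dictionary keys here are illustrative, not a fixed convention):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Bundle model weights, optimizer state, and metadata into one dict
checkpoint = {
    'epoch': 5,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pth')

# Restore everything later to resume training where it left off
ckpt = torch.load('checkpoint.pth')
model.load_state_dict(ckpt['model_state'])
optimizer.load_state_dict(ckpt['optimizer_state'])
print('Resuming from epoch', ckpt['epoch'])
```

Saving one dictionary instead of separate files keeps the model and optimizer state in sync, which matters for optimizers like Adam that carry per-parameter buffers.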
⚠️

Common Pitfalls

  • Saving entire models: Avoid saving the whole model object directly with torch.save(model, filepath) because it can cause issues when loading if the model class code changes. Instead, save the state_dict().
  • File paths: Make sure the directory exists before saving, or you will get an error.
  • Device mismatch: When loading, if the saved tensor/model was on GPU but you load on CPU, specify map_location=torch.device('cpu') in torch.load.
```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Not recommended: pickles the entire model object, which breaks
# if the class definition changes or moves
# torch.save(model, 'model_full.pth')

# Recommended: save only the learned parameters
torch.save(model.state_dict(), 'model_state.pth')
```
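The file-path and device-mismatch pitfalls above can both be handled in a few lines. A small sketch (the directory and file names are illustrative) that creates the target directory before saving and uses map_location so the load works on a CPU-only machine:

```python
import os
import torch

# Create the target directory first to avoid a FileNotFoundError
os.makedirs('checkpoints', exist_ok=True)

x = torch.tensor([1.0, 2.0, 3.0])
torch.save(x, 'checkpoints/tensor.pt')

# map_location remaps storages on load; this line is safe on a
# CPU-only machine even if the tensor had been saved from a GPU
loaded = torch.load('checkpoints/tensor.pt', map_location=torch.device('cpu'))
print(loaded.device)  # cpu
```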
📊

Quick Reference

Here is a quick summary of torch.save usage:

| Function | Description |
| --- | --- |
| `torch.save(obj, filepath)` | Save a PyTorch object to a file. |
| `torch.load(filepath)` | Load a saved PyTorch object from a file. |
| `model.state_dict()` | Get the model's parameters to save instead of the whole model. |
| `model.load_state_dict(torch.load(filepath))` | Load saved parameters into a model. |
| `torch.load(filepath, map_location='cpu')` | Load a GPU-saved object onto the CPU. |

Key Takeaways

  • Use torch.save(obj, filepath) to save tensors, models, or dictionaries.
  • Prefer saving model.state_dict() instead of the whole model object.
  • Load saved objects with torch.load(filepath), specifying map_location if needed.
  • Ensure the save directory exists to avoid file errors.
  • Saving and loading the state_dict allows flexible and safe model reuse.