How to Use torch.load in PyTorch: Load Models and Tensors
Use `torch.load(path)` to load a saved PyTorch object, such as a model or a tensor, from a file. It reads the file, deserializes its contents, and returns the saved object, which you can assign to a variable for further use.
Syntax
The basic syntax of `torch.load` is:
```python
torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, **pickle_load_args)
```
Here:
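- `f`: The file path (a string or `os.PathLike`) or a file-like object to load from.
- `map_location`: Optional remapping of where the loaded tensors are placed. It can be a string such as `'cpu'` or `'cuda:0'`, a `torch.device`, a dict, or a function.
- `pickle_module`: The module used for unpickling. The default is Python's built-in `pickle` module.
- `weights_only`: If `True`, restricts unpickling to tensors, primitive types, dictionaries and a small set of allowlisted types, which makes loading untrusted files safer. It does not extract the weights from a larger object. The default is `False` in releases before PyTorch 2.6 and `True` from 2.6 onward.
- `**pickle_load_args`: Extra keyword arguments forwarded to `pickle_module.load()`, for example `encoding='latin1'` when loading files pickled under Python 2.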
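A minimal sketch of the forms `map_location` accepts, assuming a CPU-only machine so that every form resolves to the CPU. It round-trips a tensor through an in-memory `io.BytesIO` buffer, which also shows that `f` can be a file-like object; the buffer and the `{"cuda:0": "cpu"}` mapping are illustrative choices, not requirements.

```python
import io

import torch

# torch.save/torch.load accept file-like objects, so an in-memory buffer stands in for a file
buffer = io.BytesIO()
torch.save(torch.arange(4), buffer)

# map_location as a string, a torch.device, a dict that remaps saved locations,
# or a callable (storage, location) -> storage that keeps storages where they are
locations = ["cpu", torch.device("cpu"), {"cuda:0": "cpu"}, lambda storage, loc: storage]

devices = []
for location in locations:
    buffer.seek(0)  # rewind the buffer before each load
    t = torch.load(buffer, map_location=location)
    devices.append(t.device)

print(devices)
```

On a machine with a GPU, passing `'cuda:0'` instead would place the tensors on the first GPU.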
Example
This example saves a tensor with `torch.save`, loads it back with `torch.load`, and prints the result.
```python
import torch

# Create a tensor
x = torch.tensor([1, 2, 3, 4])

# Save the tensor to a file
torch.save(x, 'tensor.pt')

# Load the tensor from the file
loaded_x = torch.load('tensor.pt')
print('Loaded tensor:', loaded_x)
```
Output
```
Loaded tensor: tensor([1, 2, 3, 4])
```
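The same call also restores containers. As a sketch, assuming a hypothetical checkpoint file named `checkpoint.pt` in a temporary directory, the snippet below saves a dictionary of tensors and an integer, then loads it with `weights_only=True`. That works here because the file holds only tensors, a dict and an int.

```python
import os
import tempfile

import torch

# A checkpoint-style dictionary (illustrative contents)
checkpoint = {
    "weights": torch.ones(2, 3),
    "bias": torch.zeros(3),
    "epoch": 5,
}

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")  # hypothetical file name
torch.save(checkpoint, path)

# Only tensors, dicts and primitive values are in the file, so weights_only=True is enough
loaded = torch.load(path, map_location="cpu", weights_only=True)
print(sorted(loaded.keys()), loaded["epoch"])  # ['bias', 'epoch', 'weights'] 5
```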
Common Pitfalls
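Common mistakes when using `torch.load` include:
- Loading a file that does not exist, or passing a wrong path, raises a `FileNotFoundError`.
- Loading a model saved with CUDA tensors on a CPU-only machine without `map_location='cpu'` raises a `RuntimeError`, because PyTorch tries to restore the tensors onto a CUDA device that is not available.
- On PyTorch 2.6 and later, `weights_only` defaults to `True`, so loading a whole pickled model or other arbitrary Python objects raises a `pickle.UnpicklingError` unless you pass `weights_only=False`. Only do that for files from a source you trust, since unpickling can execute arbitrary code.
- Confusing `torch.load` with `torch.save`: `torch.load` reads an object from a file, while `torch.save` writes one.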
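For the missing-file case, one option is to catch the error where the file is loaded. This is a sketch; `load_or_none` is a hypothetical helper name, and returning `None` is just one way to signal a missing file.

```python
import torch


def load_or_none(path, map_location="cpu"):
    """Hypothetical helper: load a file with torch.load, or return None if it is missing."""
    try:
        return torch.load(path, map_location=map_location, weights_only=True)
    except FileNotFoundError:
        # torch.load opens the path with open(), so a bad path surfaces here
        print(f"No file found at {path!r}; check the path")
        return None


result = load_or_none("does_not_exist.pt")
print(result)
```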
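Example of correctly loading a GPU-saved model on a CPU-only machine:
```python
import torch

# Map every tensor in the file onto the CPU
device = 'cpu'
# A whole pickled model needs weights_only=False on PyTorch 2.6+ (trusted files only)
model = torch.load('model_gpu.pth', map_location=device, weights_only=False)
```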
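Saving a whole pickled model ties the file to its class definitions, which is why PyTorch's saving-and-loading tutorial recommends saving the model's `state_dict` instead. A sketch of that pattern, assuming a small `nn.Linear` as a stand-in for a real network and a hypothetical file name `model_state.pth`:

```python
import os
import tempfile

import torch
from torch import nn

# A small model standing in for any trained network (illustrative architecture)
model = nn.Linear(4, 2)

path = os.path.join(tempfile.mkdtemp(), "model_state.pth")  # hypothetical file name
torch.save(model.state_dict(), path)  # save only the parameters, not the pickled module

# Rebuild the same architecture, then load the saved parameters onto the CPU
restored = nn.Linear(4, 2)
state_dict = torch.load(path, map_location="cpu", weights_only=True)
restored.load_state_dict(state_dict)
restored.eval()

x = torch.randn(1, 4)
print(torch.allclose(model(x), restored(x)))
```

Because the file contains only tensors, it loads with `weights_only=True`, so it keeps working when that default changes.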
Quick Reference
| Parameter | Description |
|---|---|
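| f | File path or file-like object to load from |
| map_location | Where to place loaded tensors: a string (e.g., 'cpu', 'cuda:0'), a torch.device, a dict, or a function |
| pickle_module | Module used for unpickling; default is pickle |
| weights_only | If True, only unpickle tensors, primitive types, dicts and other allowlisted types (default True from PyTorch 2.6) |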
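To see what `weights_only=True` rejects, the sketch below saves a dictionary holding an `argparse.Namespace` (a common non-tensor object in training checkpoints, used here only as an example) alongside a tensor. Loading it with `weights_only=True` is expected to raise `pickle.UnpicklingError`, while `weights_only=False` succeeds; the latter is only appropriate for files you trust.

```python
import argparse
import io
import pickle

import torch

# Illustrative checkpoint: a tensor plus an argparse.Namespace of training settings
buffer = io.BytesIO()
torch.save({"w": torch.ones(2), "args": argparse.Namespace(lr=0.1)}, buffer)

# weights_only=True refuses to unpickle classes outside its allowlist
buffer.seek(0)
try:
    torch.load(buffer, weights_only=True)
    rejected = False
except pickle.UnpicklingError as err:
    rejected = True
    print("weights_only=True rejected the file:", str(err)[:100])

# weights_only=False unpickles arbitrary objects; use it only for trusted files
buffer.seek(0)
full = torch.load(buffer, weights_only=False)
print(full["args"].lr)
```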
Key Takeaways
Use torch.load(path) to load saved PyTorch objects like models or tensors.
Use map_location='cpu' to load GPU-trained models on a CPU machine.
Always ensure the file path is correct to avoid file not found errors.
torch.load returns the saved object; assign it to a variable to use it.
torch.load complements torch.save, which is used to save objects.