
How to Use torch.load in PyTorch: Load Models and Tensors

Use torch.load(path) to load an object that was saved with torch.save, such as a tensor, a model's state_dict, or an entire model. It reads the file, unpickles its contents, and returns the saved object, which you can assign to a variable for further use.

Syntax

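The signature of torch.load in recent PyTorch releases is:

python
torch.load(f, map_location=None, pickle_module=None, *, weights_only=None, mmap=None, **pickle_load_args)

Here:

  • f: A file path (a string or os.PathLike) or a file-like object that implements read(), readline(), tell(), and seek().
  • map_location: Optional remapping of where loaded tensors are placed. It can be a string such as 'cpu' or 'cuda:0', a torch.device, a dict that maps saved locations to new ones, or a function.
  • pickle_module: The module used for unpickling. It must match the module torch.save used, and Python's pickle is used when it is left as None.
  • weights_only: If True, the unpickler only accepts tensors, primitive types, dictionaries, and types added with torch.serialization.add_safe_globals(). This makes loading untrusted files safer. It does not filter a checkpoint down to its weights. The default is True starting with PyTorch 2.6, and earlier releases default to False.
  • mmap: If True, the file is memory-mapped instead of having every tensor storage read into memory up front.
  • **pickle_load_args: Extra keyword arguments, such as encoding, that are passed through to the unpickler.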
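The following sketch shows that f can be a file-like object and that map_location accepts several forms. It saves a tensor into an in-memory io.BytesIO buffer and loads it back four times.

It makes these assumptions:
- PyTorch 1.13 or later, which is needed for the weights_only argument.
- A CPU-only machine.

The dict and callable forms are written so that tensors saved on the CPU stay there.

```python
import io

import torch

# Serialize a small tensor into an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.save(torch.arange(4), buffer)

# Each form of map_location below ends up placing the tensor on the CPU.
map_locations = {
    "string": "cpu",
    "device": torch.device("cpu"),
    "dict": {"cuda:0": "cpu"},  # remaps tensors saved on cuda:0; CPU tensors stay put
    "callable": lambda storage, loc: storage,  # keep each storage where it was deserialized (CPU)
}

results = {}
for name, loc in map_locations.items():
    buffer.seek(0)  # rewind the buffer before every load
    results[name] = torch.load(buffer, map_location=loc, weights_only=True)
    print(name, results[name], results[name].device)
```

On a machine with several GPUs, a typical use of the dict form is map_location={'cuda:1': 'cuda:0'}, which moves tensors saved on GPU 1 onto GPU 0.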

Example

This example saves a tensor with torch.save, loads it back with torch.load, and prints the result.

python
import torch

# Create a tensor
x = torch.tensor([1, 2, 3, 4])

# Save the tensor to a file
torch.save(x, 'tensor.pt')

# Load the tensor from the file
loaded_x = torch.load('tensor.pt')

print('Loaded tensor:', loaded_x)
Output
Loaded tensor: tensor([1, 2, 3, 4])
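The same functions also load models. One common pattern is a checkpoint dictionary:
1. Save a dictionary holding the model's state_dict, the optimizer's state_dict, and the epoch number.
2. Rebuild the model and optimizer.
3. Restore their state from the dictionary returned by torch.load.

The sketch below assumes a toy nn.Linear model and an SGD optimizer. The file name checkpoint.pt and the dictionary keys are illustrative choices, not names PyTorch requires.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Toy model and optimizer, used only for illustration.
model = nn.Linear(3, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.pt")  # hypothetical file name

    # Save a checkpoint dictionary; these keys are our own naming choice.
    torch.save(
        {
            "epoch": 5,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        },
        path,
    )

    # The file holds only dicts, tensors, and plain numbers, so weights_only=True accepts it.
    checkpoint = torch.load(path, map_location="cpu", weights_only=True)

# Rebuild the same architecture and optimizer, then restore their state.
restored = nn.Linear(3, 2)
restored_opt = torch.optim.SGD(restored.parameters(), lr=0.1)
restored.load_state_dict(checkpoint["model_state"])
restored_opt.load_state_dict(checkpoint["optimizer_state"])
start_epoch = checkpoint["epoch"] + 1
restored.eval()  # evaluation mode; matters for layers such as dropout and batch norm

x = torch.randn(1, 3)
print(start_epoch, torch.allclose(model(x), restored(x)))
```

Saving state_dicts instead of the whole model keeps the file loadable with weights_only=True. It also avoids tying the file to the exact class definitions that were pickled at save time.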

Common Pitfalls

Common mistakes when using torch.load include:

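  • Loading from a missing or mistyped path raises FileNotFoundError.
  • Loading a checkpoint that contains CUDA tensors on a machine without a GPU raises a RuntimeError unless you pass map_location='cpu'.
  • Loading a whole pickled model (saved with torch.save(model)) under weights_only=True, the default since PyTorch 2.6, fails with an UnpicklingError. Pass weights_only=False only for files you trust, or save and load the state_dict instead.
  • Confusing torch.load with torch.save: torch.save writes an object to a file, and torch.load reads it back.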
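The weights_only behavior is easy to reproduce. This sketch assumes PyTorch 1.13 or later and works in three steps:
1. Pickle a whole nn.Linear module into an in-memory buffer.
2. Show that weights_only=True rejects it.
3. Load it with weights_only=False, which is acceptable here only because this same script created the buffer.

```python
import io
import pickle

import torch
import torch.nn as nn

# torch.save on a module pickles the whole object, including a reference to its class.
buffer = io.BytesIO()
torch.save(nn.Linear(2, 1), buffer)

# The restricted unpickler refuses arbitrary classes such as nn.Linear.
buffer.seek(0)
try:
    torch.load(buffer, weights_only=True)
    rejected = False
except pickle.UnpicklingError:
    rejected = True
print("rejected with weights_only=True:", rejected)

# Full unpickling can execute arbitrary code, so use it only for trusted files.
buffer.seek(0)
model = torch.load(buffer, weights_only=False)
print(type(model).__name__)
```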

Example of correctly loading a model saved on a GPU onto a CPU-only machine:

python
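import torch

# Load a checkpoint saved on a GPU onto a CPU-only machine.
# map_location remaps every CUDA tensor in the file to CPU memory.
device = torch.device('cpu')
# If the file holds a whole pickled model rather than a state_dict, PyTorch 2.6+
# also needs weights_only=False; only do that for files you trust.
model = torch.load('model_gpu.pth', map_location=device)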
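Sometimes the same script must run on machines both with and without a GPU. One option is to pick the device at runtime with torch.cuda.is_available() and pass it as map_location. The sketch below assumes a hypothetical file name, state.pt, written to a temporary directory.

```python
import os
import tempfile

import torch

# Choose the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "state.pt")  # hypothetical file name
    torch.save({"weight": torch.ones(2, 2)}, path)

    # Every tensor in the file is placed on `device`, regardless of where it was saved from.
    state = torch.load(path, map_location=device, weights_only=True)

print(state["weight"].device)
```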

Quick Reference

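| Parameter | Description |
|---|---|
| f | File path or file-like object to load from |
| map_location | Where to place loaded tensors (e.g., 'cpu', 'cuda:0', a torch.device, a dict, or a function) |
| pickle_module | Module used for unpickling; defaults to Python's pickle |
| weights_only | If True, restrict unpickling to tensors, primitive types, dictionaries, and allowlisted types (default True since PyTorch 2.6) |
| mmap | If True, memory-map the file instead of reading all storages into memory |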

Key Takeaways

Use torch.load(path) to load objects saved with torch.save, such as tensors, state_dicts, or whole models.
Use map_location='cpu' to load GPU-trained models on a CPU-only machine.
Double-check the file path; a wrong path raises FileNotFoundError.
torch.load returns the saved object; assign it to a variable to use it.
Since PyTorch 2.6, weights_only defaults to True; set it to False only for files you trust.
torch.load complements torch.save, which writes objects to disk.