
Checkpoint with optimizer state in PyTorch

Introduction

Saving a checkpoint with the optimizer state lets you pause and continue training later without losing progress.

You want to stop training and resume later without starting over.
You want to save your model and optimizer to recover from crashes.
You want to try different training settings starting from the same point.
You want to share your trained model and optimizer state with others.
Syntax
PyTorch
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'loss': loss_value
}, PATH)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss_value = checkpoint['loss']

Use model.state_dict() and optimizer.state_dict() to get their states.

Loading the optimizer state restores internal values such as momentum buffers and the current learning rate, so training continues smoothly from where it stopped.
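The two halves of the syntax above are often wrapped in small helper functions. `save_checkpoint` and `load_checkpoint` below are illustrative names, not part of the PyTorch API:

```python
import torch

def save_checkpoint(model, optimizer, epoch, loss_value, path):
    # Bundle everything needed to resume training into one file.
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss_value,
    }, path)

def load_checkpoint(model, optimizer, path):
    # Restore both states in place; return the bookkeeping values.
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']
```

Keeping the save and load formats in one place like this makes it harder for the two to drift apart as the checkpoint dictionary grows.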

Examples
Saves model and optimizer states to a file named checkpoint.pth.
PyTorch
torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, 'checkpoint.pth')
Loads the saved states back into model and optimizer.
PyTorch
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
Also saves the current epoch number to resume training from the right place.
PyTorch
torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'epoch': current_epoch}, 'checkpoint.pth')
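The saved epoch can then seed the loop counter when training resumes. The sketch below is self-contained so it runs as-is; num_epochs and the empty loop body are placeholders for a real training setup:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Save a checkpoint as if training stopped after epoch 4.
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': 4}, 'checkpoint.pth')

# Later: restore states and continue from the saved epoch.
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']
num_epochs = 6  # hypothetical total

epochs_run = []
for epoch in range(start_epoch, num_epochs):
    epochs_run.append(epoch)  # a real epoch of training would go here

print(epochs_run)  # [4, 5]
```

Only epochs 4 and 5 run, so no work from the first four epochs is repeated.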
Sample Model

This code trains a simple model for one step, saves the model and optimizer states along with epoch and loss, then loads them back and prints the saved epoch and loss.

PyTorch
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)
    def forward(self, x):
        return self.linear(x)

# Create model and optimizer
model = SimpleNet()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Dummy data
inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
targets = torch.tensor([[1.0], [2.0]])

# Loss function
criterion = nn.MSELoss()

# Training step
model.train()
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

# Save checkpoint
checkpoint_path = 'checkpoint.pth'
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': 1,
    'loss': loss.item()
}, checkpoint_path)

# Load checkpoint
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
loaded_epoch = checkpoint['epoch']
loaded_loss = checkpoint['loss']

print(f"Loaded epoch: {loaded_epoch}")
print(f"Loaded loss: {loaded_loss:.4f}")
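As a quick check that loading really does restore optimizer internals, the sketch below uses SGD with momentum (plain SGD keeps no per-parameter state) and compares the momentum buffers before and after a save/load round trip:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One step so the optimizer accumulates momentum buffers.
loss = model(torch.randn(4, 2)).pow(2).mean()
loss.backward()
optimizer.step()

torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, 'ckpt.pth')

# Fresh model and optimizer, as after a restart.
model2 = nn.Linear(2, 1)
optimizer2 = optim.SGD(model2.parameters(), lr=0.1, momentum=0.9)

ckpt = torch.load('ckpt.pth')
model2.load_state_dict(ckpt['model_state_dict'])
optimizer2.load_state_dict(ckpt['optimizer_state_dict'])

# The restored momentum buffer matches the original exactly.
buf1 = optimizer.state_dict()['state'][0]['momentum_buffer']
buf2 = optimizer2.state_dict()['state'][0]['momentum_buffer']
print(torch.equal(buf1, buf2))  # True
```

If the optimizer state were skipped, optimizer2 would start with empty momentum buffers and the first resumed steps would behave differently from uninterrupted training.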
Important Notes

Always save both model and optimizer states to continue training smoothly.

Include extra info like epoch and loss to track training progress.

Use torch.load(PATH, map_location=torch.device('cpu')) when loading a GPU-trained checkpoint on a machine without a GPU.
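For instance, the sketch below passes map_location when loading. The file here is written on CPU so the snippet runs anywhere, but the same call remaps every tensor in a GPU-written checkpoint onto the CPU:

```python
import torch
import torch.nn as nn

# Write a checkpoint (on a real GPU machine the tensors would be CUDA tensors).
model = nn.Linear(2, 1)
torch.save({'model_state_dict': model.state_dict()}, 'portable_ckpt.pth')

# map_location remaps all saved tensors onto the CPU while loading.
checkpoint = torch.load('portable_ckpt.pth',
                        map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
print(checkpoint['model_state_dict']['weight'].device)  # cpu
```

Without map_location, loading a CUDA checkpoint on a CPU-only machine raises an error because PyTorch tries to place the tensors back on their original device.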

Summary

Checkpoint saves model and optimizer states to pause and resume training.

Loading the optimizer state restores training settings such as the learning rate and momentum buffers.

Include epoch and loss in checkpoint to track training progress.