
Why checkpointing preserves progress in PyTorch

Introduction

Checkpointing saves your model and training state so you don't lose progress if something stops your training.

Checkpointing is useful when:

Training takes a long time and you want to pause and continue later.
Your machine might shut down or lose power during training.
You want to try different training settings while keeping the best model so far.
You want to save intermediate results instead of starting from scratch after a crash.
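The "keep the best model so far" case can be sketched like this. The tiny model, data, and filename are illustrative placeholders, not part of the lesson's sample program:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Placeholder model and data, just to demonstrate best-checkpoint tracking.
model = nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
x = torch.tensor([[1.0], [2.0], [3.0]])
y = torch.tensor([[2.0], [4.0], [6.0]])

best_loss = float('inf')
for epoch in range(5):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    # Overwrite the checkpoint only when the loss improves,
    # so 'best_model.pth' always holds the best weights seen so far.
    if loss.item() < best_loss:
        best_loss = loss.item()
        torch.save(model.state_dict(), 'best_model.pth')
```

Even if later epochs make things worse (for example after changing the learning rate), the best weights remain on disk.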
Syntax
PyTorch
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss_value,
}, PATH)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

Use model.state_dict() to save model weights.

Use optimizer.state_dict() to save optimizer state for continuing training.
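To see what these two dictionaries actually contain, you can inspect their keys on a throwaway model (the nn.Linear here is only for illustration):

```python
import torch.nn as nn
import torch.optim as optim

# Throwaway model just to inspect the dictionaries.
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# model.state_dict() maps parameter names to weight tensors.
print(list(model.state_dict().keys()))      # ['weight', 'bias']

# optimizer.state_dict() holds hyperparameters and per-parameter
# buffers (e.g. momentum) needed to resume training exactly.
print(list(optimizer.state_dict().keys()))  # ['state', 'param_groups']
```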

Examples
Saves only the model weights, useful for inference later.
PyTorch
torch.save(model.state_dict(), 'model.pth')
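Weights saved this way are loaded back into a freshly constructed model of the same architecture. A minimal sketch, assuming the same SimpleModel class used elsewhere in this lesson:

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):  # must match the architecture that was saved
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)
    def forward(self, x):
        return self.linear(x)

# Save weights only, as in the example above.
model = SimpleModel()
torch.save(model.state_dict(), 'model.pth')

# Recreate the model and load the saved weights for inference.
restored = SimpleModel()
restored.load_state_dict(torch.load('model.pth'))
restored.eval()  # switch to inference mode (affects dropout/batchnorm)
```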
Saves model, optimizer, and epoch to resume training exactly.
PyTorch
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': epoch
}, 'checkpoint.pth')
Loads checkpoint to continue training from saved state.
PyTorch
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
Sample Program

This program trains a simple model for 3 epochs, saves a checkpoint, then loads it and continues training for 2 more epochs. The loss keeps decreasing across all 5 epochs, showing that no progress was lost at the save/load boundary.

PyTorch
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)
    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

# Dummy data
x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y = torch.tensor([[2.0], [4.0], [6.0], [8.0]])

# Train for 3 epochs and save checkpoint
for epoch in range(3):
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Save checkpoint
torch.save({
    'epoch': 3,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item()
}, 'checkpoint.pth')

# Load checkpoint and continue training for 2 more epochs
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']

for epoch in range(start_epoch, start_epoch + 2):
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
Important Notes

Always save both model and optimizer states to resume training properly.

Checkpoints let you avoid losing hours of training if interrupted.

You can also save checkpoints periodically during training to keep backups.
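Periodic saving can be as simple as a modulo check inside the training loop. A minimal sketch with placeholder model, data, and filename:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Placeholder model and data for the sketch.
model = nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
x = torch.tensor([[1.0], [2.0]])
y = torch.tensor([[2.0], [4.0]])

SAVE_EVERY = 2  # save a backup every 2 epochs
for epoch in range(6):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % SAVE_EVERY == 0:
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
        }, 'backup.pth')  # overwrites the previous backup
```

One related tip: if a checkpoint was saved on a GPU and you load it on a CPU-only machine, pass map_location, e.g. torch.load('backup.pth', map_location='cpu'), to remap the tensors.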

Summary

Checkpointing saves your training progress so you can stop and continue later.

It saves model weights, optimizer state, and other info like epoch number.

This helps avoid losing work and makes training more flexible and safe.