Checkpointing saves your model and training state so you don't lose progress if something stops your training.
Why checkpointing preserves progress in PyTorch
Introduction
Checkpointing is useful in several situations:
When training takes a long time and you want to pause and continue later.
When your machine might shut down or lose power during training.
When you want to try different training settings while keeping the best model so far.
When you want to save intermediate results and avoid starting from scratch after a crash.
Syntax
PyTorch
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss_value,
}, PATH)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
Use model.state_dict() to save model weights.
Use optimizer.state_dict() to save optimizer state for continuing training.
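To see what a state dict actually holds, you can inspect one for a tiny layer (a minimal sketch, not part of the syntax above): state_dict() returns an ordered mapping of parameter names to tensors.

```python
import torch
import torch.nn as nn

# A single linear layer with 2 inputs and 3 outputs
layer = nn.Linear(2, 3)

# state_dict() maps parameter names to their tensors
state = layer.state_dict()

print(list(state.keys()))     # ['weight', 'bias']
print(state['weight'].shape)  # torch.Size([3, 2])
```

Because the state dict is just named tensors, it can be saved, loaded, and moved between devices independently of the model's Python class.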
Examples
Saves only the model weights, useful for inference later.
PyTorch
torch.save(model.state_dict(), 'model.pth')
Saves model, optimizer, and epoch to resume training exactly.
PyTorch
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': epoch
}, 'checkpoint.pth')
Loads checkpoint to continue training from saved state.
PyTorch
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
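If a checkpoint was saved on a GPU machine and later loaded on a CPU-only machine, pass map_location to torch.load so the tensors are remapped to an available device. A minimal self-contained sketch (the file name 'checkpoint.pth' matches the example above):

```python
import torch
import torch.nn as nn

# Save a small checkpoint (in practice this would happen on the training machine)
model = nn.Linear(1, 1)
torch.save({'model': model.state_dict()}, 'checkpoint.pth')

# map_location='cpu' remaps all saved tensors to the CPU while loading,
# so a checkpoint written on a GPU machine loads fine without CUDA.
checkpoint = torch.load('checkpoint.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])
```

Without map_location, loading a GPU-saved checkpoint on a machine without CUDA raises an error, so it is a common default for portable loading code.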
Sample Model
This program trains a simple model for 3 epochs, saves a checkpoint, then loads it and continues training for 2 more epochs. You see the loss decrease over all 5 epochs without losing progress.
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

# Dummy data
x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y = torch.tensor([[2.0], [4.0], [6.0], [8.0]])

# Train for 3 epochs
for epoch in range(3):
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Save checkpoint
torch.save({
    'epoch': 3,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item()
}, 'checkpoint.pth')

# Load checkpoint and continue training for 2 more epochs
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']

for epoch in range(start_epoch, start_epoch + 2):
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
Output
The program prints the loss for epochs 1 through 5; the loss keeps decreasing after the checkpoint is restored, showing no progress was lost.
Important Notes
Always save both model and optimizer states to resume training properly.
Checkpoints let you avoid losing hours of training if interrupted.
You can also save checkpoints periodically during training to keep backups.
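The notes above can be sketched in a small loop. This is a minimal illustration, not a fixed recipe: the names save_every, best_loss, 'backup.pth', and 'best_model.pth' are all illustrative choices, not part of the PyTorch API.

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y = torch.tensor([[2.0], [4.0], [6.0], [8.0]])

save_every = 2            # illustrative: write a backup every 2 epochs
best_loss = float('inf')  # illustrative: track the best loss seen so far

for epoch in range(6):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

    # Periodic backup: save model AND optimizer so training can resume exactly
    if (epoch + 1) % save_every == 0:
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, 'backup.pth')

    # Separately keep a copy of the best model so far (weights only)
    if loss.item() < best_loss:
        best_loss = loss.item()
        torch.save(model.state_dict(), 'best_model.pth')
```

Keeping the periodic backup and the best-model file separate means a crash costs at most save_every epochs, while the best weights are always available for inference.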
Summary
Checkpointing saves your training progress so you can stop and continue later.
It saves model weights, optimizer state, and other info like epoch number.
This helps avoid losing work and makes training more flexible and safe.