Saving a checkpoint with the optimizer state lets you pause and continue training later without losing progress.
Checkpoint with optimizer state in PyTorch
# Save model and optimizer state together in one checkpoint
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'loss': loss_value
}, PATH)
# Load the checkpoint and restore both states
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss_value = checkpoint['loss']

Call model.state_dict() and optimizer.state_dict() to get each object's serializable state.
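A hedged sketch of how the restored values might be used to resume; num_epochs and train_one_epoch are hypothetical stand-ins for your own training configuration and loop body:

# Continue from the epoch after the one that was saved.
# num_epochs and train_one_epoch are hypothetical placeholders.
start_epoch = checkpoint['epoch'] + 1
model.train()  # restore training mode (dropout, batch norm) before continuing
for epoch in range(start_epoch, num_epochs):
    train_one_epoch(model, optimizer)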
Loading the optimizer state restores the learning rate, momentum buffers, and other per-parameter state, so training continues smoothly from where it left off.
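For example (a minimal sketch, assuming SGD with momentum), the optimizer state dict carries both the hyperparameter settings and the per-parameter momentum buffers:

import torch
import torch.optim as optim

param = torch.nn.Parameter(torch.ones(2))
opt = optim.SGD([param], lr=0.1, momentum=0.9)
param.sum().backward()
opt.step()  # the first step populates the momentum buffers

state = opt.state_dict()
# 'param_groups' holds settings such as lr and momentum;
# 'state' holds per-parameter buffers like momentum_buffer
print(state['param_groups'][0]['lr'])        # 0.1
print(state['state'][0]['momentum_buffer'])  # tensor([1., 1.])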
In its shortest form, both states round-trip through a single file such as checkpoint.pth:

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}, 'checkpoint.pth')

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': current_epoch
}, 'checkpoint.pth')

The full example below trains a simple model for one step, saves the model and optimizer states along with the epoch and loss, then loads them back and prints the saved epoch and loss.
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

# Create model and optimizer
model = SimpleNet()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Dummy data
inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
targets = torch.tensor([[1.0], [2.0]])

# Loss function
criterion = nn.MSELoss()

# Training step
model.train()
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

# Save checkpoint
checkpoint_path = 'checkpoint.pth'
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': 1,
    'loss': loss.item()
}, checkpoint_path)

# Load checkpoint
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
loaded_epoch = checkpoint['epoch']
loaded_loss = checkpoint['loss']

print(f"Loaded epoch: {loaded_epoch}")
print(f"Loaded loss: {loaded_loss:.4f}")
Always save both model and optimizer states to continue training smoothly.
Include extra info like epoch and loss to track training progress.
Use torch.load(PATH, map_location=torch.device('cpu')) when loading a GPU-trained checkpoint on a CPU-only machine, as sketched below.
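A minimal sketch of device-agnostic loading, assuming model and optimizer are already constructed as in the example above:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location remaps tensors saved on GPU to whatever device is available
checkpoint = torch.load('checkpoint.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])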
A checkpoint saves the model and optimizer states so you can pause and resume training.
Loading the optimizer state restores training settings such as the learning rate and momentum.
Including the epoch and loss in the checkpoint lets you track training progress.