Imagine you are training a neural network, a process that takes hours to complete. You want to save your progress so you can continue later without starting over. Why does saving a checkpoint help preserve your training progress?
Think about what information is needed to continue training without losing progress.
Checkpointing saves the model's weights and optimizer state. This means when you load the checkpoint, the model continues training from the exact point it stopped, preserving all learned information and optimizer momentum.
Consider this PyTorch code snippet that saves and loads a checkpoint during training:
import torch
import torch.nn as nn
import torch.optim as optim
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
# Simulate training step
for param in model.parameters():
    param.data.fill_(1.0)
# Save checkpoint
checkpoint = {'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict()}
torch.save(checkpoint, 'checkpoint.pth')
# Reset model weights to zero
for param in model.parameters():
    param.data.fill_(0.0)
# Load checkpoint
loaded = torch.load('checkpoint.pth')
model.load_state_dict(loaded['model_state'])
# What is the value of model.weight after loading?
print(model.weight)
Loading the checkpoint restores the saved weights exactly.
After loading the checkpoint, the model's weights are restored to the saved values (all ones). The reset to zero is overwritten by loading the checkpoint.
When saving a checkpoint in PyTorch, which hyperparameter related to the optimizer must be saved to correctly resume training?
Think about what the optimizer needs to continue updating weights properly.
The optimizer's state_dict records hyperparameters such as the learning rate and momentum (in its param_groups), along with per-parameter buffers. Saving and restoring this state ensures the optimizer continues updating weights exactly as it did before the checkpoint.
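To see what the optimizer actually records, here is a small sketch (not part of the original lesson) that inspects an SGD optimizer's state_dict; the momentum value of 0.9 is an illustrative assumption:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

state = optimizer.state_dict()
# 'param_groups' holds hyperparameters such as lr and momentum;
# 'state' holds per-parameter buffers (e.g. momentum buffers after a step).
print(state['param_groups'][0]['lr'])
print(state['param_groups'][0]['momentum'])
```

Because these hyperparameters live inside the state_dict, saving it is enough to resume with the same learning rate and momentum; you do not need to save them separately.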
Look at this PyTorch code snippet that tries to load a checkpoint but raises an error:
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer'])  # Error here
What is the cause of the error?
Check the exact keys used when saving the checkpoint.
The checkpoint dictionary uses the key 'optimizer_state' for the optimizer's state dict, but the code tries to access 'optimizer', causing a KeyError.
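A corrected version might look like the following sketch: it uses the same keys at load time as at save time, and adds a membership check so a key mismatch fails with a clear message rather than a bare KeyError:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Save with explicit keys.
checkpoint = {'model_state': model.state_dict(),
              'optimizer_state': optimizer.state_dict()}
torch.save(checkpoint, 'checkpoint.pth')

# Load with the SAME keys.
loaded = torch.load('checkpoint.pth')
assert 'optimizer_state' in loaded, "checkpoint is missing the optimizer state"
model.load_state_dict(loaded['model_state'])
optimizer.load_state_dict(loaded['optimizer_state'])  # matches the saved key
```

Defining the key names as constants shared by the save and load code is a simple way to prevent this class of bug.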
You are training a very large neural network that takes days to train. You want to save checkpoints efficiently without losing progress and minimize storage. Which checkpointing strategy is best?
Consider storage size and ability to resume training exactly.
Saving only the model's and optimizer's state_dicts is efficient and sufficient to resume training exactly. Saving the entire model object is larger and less flexible. Saving data batches does not preserve model progress. Saving only after training loses all progress if interrupted.
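The strategy above can be sketched as a periodic checkpointing loop; the checkpoint interval, filenames, and toy training data below are illustrative assumptions, not part of the original lesson:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
x, y = torch.randn(8, 2), torch.randn(8, 1)

for epoch in range(6):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    # Save only the state_dicts, every 3 epochs: small files,
    # yet enough (with the epoch counter) to resume exactly.
    if (epoch + 1) % 3 == 0:
        torch.save({'epoch': epoch,
                    'model_state': model.state_dict(),
                    'optimizer_state': optimizer.state_dict()},
                   f'checkpoint_epoch{epoch}.pth')
```

Recording the epoch alongside the state_dicts lets a resumed run pick up the training loop at the right place, while keeping each checkpoint far smaller than a pickled copy of the whole model object.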