Checkpointing saves your model and training state so you don't lose progress if something stops your training.
Why checkpointing preserves progress in PyTorch
# Save a checkpoint
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss_value,
}, PATH)

# Load the checkpoint and restore the saved state
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

Use model.state_dict() to save model weights.
Use optimizer.state_dict() to save optimizer state for continuing training.
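To make those two calls concrete, here is a small sketch of what each state dict holds. The single-layer model and the momentum setting are assumptions chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# A tiny illustrative model; the layer size is arbitrary.
model = nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# model.state_dict() maps parameter names to tensors.
model_keys = list(model.state_dict().keys())
print(model_keys)

# optimizer.state_dict() holds per-parameter buffers ('state')
# and hyperparameters such as the learning rate ('param_groups').
opt_state = optimizer.state_dict()
opt_keys = sorted(opt_state.keys())
print(opt_keys)
print(opt_state['param_groups'][0]['lr'])
```

Both objects are plain dictionaries, which is why they can be bundled into one checkpoint dictionary and passed to torch.save().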
# Save model weights only
torch.save(model.state_dict(), 'model.pth')

# Save a full checkpoint
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': epoch
}, 'checkpoint.pth')

# Load the full checkpoint
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
This program trains a simple model for 3 epochs, saves a checkpoint, then loads it and continues training for 2 more epochs. You see the loss decrease over all 5 epochs without losing progress.
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

# Dummy data
x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y = torch.tensor([[2.0], [4.0], [6.0], [8.0]])

# Train for 3 epochs and save checkpoint
for epoch in range(3):
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Save checkpoint
torch.save({
    'epoch': 3,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item()
}, 'checkpoint.pth')

# Load checkpoint and continue training for 2 more epochs
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']

for epoch in range(start_epoch, start_epoch + 2):
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
Always save both model and optimizer states to resume training properly.
Checkpoints let you avoid losing hours of training if interrupted.
You can also save checkpoints periodically during training to keep backups.
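Periodic saving can be sketched as follows. The interval `save_every`, the epoch count, the temporary directory, and the file naming scheme are all hypothetical choices for illustration, not values prescribed by PyTorch.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y = torch.tensor([[2.0], [4.0], [6.0], [8.0]])

# Hypothetical settings chosen for illustration.
num_epochs = 6
save_every = 2
ckpt_dir = tempfile.mkdtemp()

saved_paths = []
for epoch in range(1, num_epochs + 1):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

    # Write a checkpoint every `save_every` epochs.
    if epoch % save_every == 0:
        path = os.path.join(ckpt_dir, f"checkpoint_epoch_{epoch}.pth")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
        }, path)
        saved_paths.append(path)

print([os.path.basename(p) for p in saved_paths])
```

Keeping one file per saved epoch makes it possible to roll back to an earlier point. For long runs you may prefer to overwrite a single file or delete older checkpoints to save disk space.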
Checkpointing saves your training progress so you can stop and continue later.
It saves model weights, optimizer state, and other info like epoch number.
This helps avoid losing work and makes training more flexible and safe.
Practice
Solution
Step 1: Understand checkpointing purpose
Checkpointing saves the model's current state including weights and optimizer info.
Step 2: Connect checkpointing to training progress
This allows training to stop and resume later without losing progress.
Final Answer:
To save the model's current state so training can resume later without loss -> Option A
Quick Check:
Checkpointing = Save progress [OK]
- Thinking checkpointing speeds up training
- Confusing checkpointing with data reduction
- Assuming checkpointing tunes hyperparameters
Solution
Step 1: Identify saving function
torch.save() is used to save objects like model weights to a file.
Step 2: Check correct usage for saving model state
model.state_dict() returns model weights; saving it with torch.save() is correct.
Final Answer:
torch.save(model.state_dict(), 'checkpoint.pth') -> Option B
Quick Check:
Save model weights = torch.save(state_dict) [OK]
- Using torch.load instead of torch.save to save
- Trying to save optimizer with wrong method
- Confusing load_state_dict with saving
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
epoch = checkpoint['epoch']
print(epoch)
Solution
Step 1: Understand checkpoint contents
The checkpoint dictionary contains keys 'model_state', 'optimizer_state', and 'epoch'.
Step 2: Identify printed value
Variable 'epoch' is assigned checkpoint['epoch'], so print(epoch) outputs the saved epoch number.
Final Answer:
The epoch number saved in the checkpoint -> Option D
Quick Check:
Print epoch from checkpoint = epoch number [OK]
- Thinking print shows model parameters count
- Confusing optimizer state with epoch
- Assuming missing keys cause error here
RuntimeError: Error(s) in loading state_dict. What is the most likely cause related to checkpointing?
Solution
Step 1: Understand error meaning
Loading state_dict errors usually happen if model layers differ from saved checkpoint.
Step 2: Connect error to checkpoint cause
If model architecture changed after saving, weights won't match, causing this error.
Final Answer:
The model architecture changed after saving the checkpoint -> Option C
Quick Check:
State_dict error = architecture mismatch [OK]
- Confusing save/load functions causing error
- Assuming missing optimizer state causes this error
- Blaming training data changes for state_dict error
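The architecture mismatch described in this solution can be reproduced in a few lines. This is a minimal sketch: the two layer shapes are assumptions picked only to force a mismatch, and the state dict is kept in memory rather than written to a file.

```python
import torch
import torch.nn as nn

# State saved from a model with one architecture...
saved_state = nn.Linear(4, 2).state_dict()

# ...then loaded into a model whose layer shape has changed.
changed_model = nn.Linear(4, 3)

try:
    changed_model.load_state_dict(saved_state)
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```

The message names the parameters whose shapes differ, which is usually the quickest way to find the layer that changed. Renamed or removed layers show up in the same error as missing or unexpected keys.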
Solution
Step 1: Identify what preserves full training state
Saving model weights, optimizer state, and epoch number allows full resume.
Step 2: Compare options
Only saving model weights misses optimizer info; saving optimizer and epoch without model is incomplete; saving data batch doesn't preserve progress.
Final Answer:
Save a dictionary with model.state_dict(), optimizer.state_dict(), and current epoch number -> Option A
Quick Check:
Checkpoint = model + optimizer + epoch [OK]
- Saving only model weights loses optimizer progress
- Ignoring epoch number causes restart from zero
- Saving training data batch does not preserve model state
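To see why losing the optimizer state matters, the sketch below compares two ways of resuming against an uninterrupted run. It assumes SGD with momentum (so the optimizer has internal state) and uses an in-memory copy of the state dicts in place of torch.save() and torch.load(). The helper names `make_optimizer` and `train` are hypothetical.

```python
import copy

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y = torch.tensor([[2.0], [4.0], [6.0], [8.0]])
criterion = nn.MSELoss()

def make_optimizer(model):
    # Momentum gives the optimizer internal state worth saving.
    return optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

def train(model, optimizer, steps):
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

# Train 3 steps, snapshot the state in memory, then train 2 more.
model = nn.Linear(1, 1)
optimizer = make_optimizer(model)
train(model, optimizer, 3)
checkpoint = {
    'model': copy.deepcopy(model.state_dict()),
    'optimizer': copy.deepcopy(optimizer.state_dict()),
    'epoch': 3,
}
train(model, optimizer, 2)
reference = model.weight.item()

# Resume with model and optimizer state restored.
full = nn.Linear(1, 1)
full.load_state_dict(checkpoint['model'])
full_opt = make_optimizer(full)
full_opt.load_state_dict(checkpoint['optimizer'])
train(full, full_opt, 2)

# Resume with model weights only and a fresh optimizer.
partial = nn.Linear(1, 1)
partial.load_state_dict(checkpoint['model'])
partial_opt = make_optimizer(partial)
train(partial, partial_opt, 2)

full_diff = abs(full.weight.item() - reference)
partial_diff = abs(partial.weight.item() - reference)
print(full_diff, partial_diff)
```

The first difference should be zero or numerically negligible, while the second should be clearly larger, because the fresh optimizer starts without the momentum accumulated during the first three steps.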
