PyTorch · ~20 mins

Why checkpointing preserves progress in PyTorch - An Experiment to Prove It

Experiment - Why checkpointing preserves progress
Problem: You are training a neural network on a dataset, and training takes a long time. If the training process stops unexpectedly, you lose all progress and must start over.
Current Metrics: Training accuracy: 85%, Validation accuracy: 80%, Training loss: 0.5, Validation loss: 0.6
Issue: No checkpointing is used, so if training is interrupted, all progress is lost and training must restart from scratch.
Your Task
Implement checkpointing to save the model and optimizer state during training so that training can resume from the last saved point without losing progress.
Use PyTorch's native checkpointing methods.
Save checkpoints every 2 epochs.
Ensure that after restarting from a checkpoint, training continues correctly.
Solution
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Simple dataset
X = torch.randn(1000, 10)
y = (X.sum(dim=1) > 0).long()
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
    def forward(self, x):
        return self.fc(x)

model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Checkpoint path
checkpoint_path = 'checkpoint.pth'

# Function to save checkpoint
def save_checkpoint(epoch, model, optimizer):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, checkpoint_path)

# Function to load checkpoint
def load_checkpoint(model, optimizer):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

# Try to load checkpoint
start_epoch = 0
try:
    start_epoch = load_checkpoint(model, optimizer) + 1
    print(f'Resuming training from epoch {start_epoch}')
except FileNotFoundError:
    print('No checkpoint found, starting fresh training')

# Training loop
num_epochs = 10
for epoch in range(start_epoch, num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    avg_loss = total_loss / total
    accuracy = correct / total * 100
    print(f'Epoch {epoch}: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%')

    # Save checkpoint every 2 epochs
    if (epoch + 1) % 2 == 0:
        save_checkpoint(epoch, model, optimizer)
        print(f'Checkpoint saved at epoch {epoch + 1}')
Added functions to save and load checkpoints using torch.save and torch.load.
Saved model state, optimizer state, and current epoch in checkpoint.
Modified training loop to load checkpoint if available and resume training.
Saved checkpoint every 2 epochs to preserve progress.
Fixed checkpoint saved epoch print statement to show correct epoch number.
Results Interpretation

Before checkpointing: If training stops, all progress is lost and training restarts from epoch 0.

After checkpointing: Training resumes from the last saved epoch, preserving progress and saving time.

Checkpointing saves the model and optimizer states during training. This allows training to continue from the last saved point after interruptions, preventing loss of progress and saving time.
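The resume behavior above can be sanity-checked in isolation: after a save/load round trip, a freshly constructed model loaded from the checkpoint should have parameters identical to the original. A minimal sketch (the tiny model, optimizer settings, and file path here are illustrative, not part of the exercise code):

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical tiny model, just to exercise the round trip
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Take one training step so the optimizer has momentum buffers to save
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()

# Save epoch, model state, and optimizer state in one checkpoint dict
path = os.path.join(tempfile.gettempdir(), 'roundtrip.pth')
torch.save({'epoch': 3,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, path)

# Fresh copies, as if the process had been restarted
restored = nn.Linear(4, 2)
restored_opt = optim.SGD(restored.parameters(), lr=0.1, momentum=0.9)
ckpt = torch.load(path)
restored.load_state_dict(ckpt['model_state_dict'])
restored_opt.load_state_dict(ckpt['optimizer_state_dict'])

# Every parameter matches exactly after the round trip
for p, q in zip(model.parameters(), restored.parameters()):
    assert torch.equal(p, q)
print('round trip ok, resume at epoch', ckpt['epoch'] + 1)
```

Because the optimizer state (momentum buffers, step counts) is restored along with the weights, the first update after resuming behaves as if training had never stopped.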
Bonus Experiment
Try implementing checkpointing that also saves the best model based on validation accuracy.
💡 Hint
Track validation accuracy each epoch and save a separate checkpoint only when validation accuracy improves.
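One way to sketch this, assuming a `maybe_save_best` helper and simulated per-epoch accuracy values (both hypothetical, not part of the exercise): track the running best validation accuracy and overwrite a separate checkpoint file only when it improves.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical model and optimizer, standing in for the exercise's SimpleNet
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

best_path = 'best_model.pth'
best_acc = 0.0

def maybe_save_best(epoch, val_acc):
    """Overwrite the best-model checkpoint only when validation accuracy improves."""
    global best_acc
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({'epoch': epoch,
                    'val_acc': val_acc,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()},
                   best_path)
        return True
    return False

# Simulated validation accuracies per epoch (illustrative numbers only)
for epoch, val_acc in enumerate([0.70, 0.78, 0.75, 0.81]):
    if maybe_save_best(epoch, val_acc):
        print(f'New best ({val_acc:.2f}) saved at epoch {epoch}')
```

This keeps two files: the regular every-2-epochs checkpoint for resuming, and a separate best-model checkpoint that always holds the weights with the highest validation accuracy seen so far.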