
Best model saving pattern in PyTorch

Introduction

Saving the best model keeps the version that scores best on validation data during training, so you can use it later without retraining.

When you train a model over many epochs and want to keep the best version.
When you want to avoid losing the best model if training stops unexpectedly.
When you want to compare different saved models later.
When you want to load the best model for testing or deployment.
When you want to save disk space by only keeping the best model, not all checkpoints.
Syntax
PyTorch
if current_val_loss < best_val_loss:
    best_val_loss = current_val_loss
    torch.save(model.state_dict(), 'best_model.pth')

Use model.state_dict() to save only the model's weights (its learned parameters and buffers) instead of the whole model object, which keeps the file small and portable.

Compare a validation metric to decide whether the current model is better: for loss, lower is better (use <); for accuracy, higher is better (use >).
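The comparison direction can be wrapped in a small helper so the same loop works for loss and accuracy. This is only a sketch: the names is_improvement, initial_best, and the mode argument are made up for illustration and are not part of PyTorch.

```python
def is_improvement(current, best, mode="min"):
    """Return True when `current` beats `best`.

    mode="min": lower is better (e.g. validation loss).
    mode="max": higher is better (e.g. validation accuracy).
    """
    if mode == "min":
        return current < best
    if mode == "max":
        return current > best
    raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")


def initial_best(mode="min"):
    """Starting value that any real metric will beat on the first epoch."""
    return float("inf") if mode == "min" else float("-inf")


print(is_improvement(0.42, initial_best("min")))   # True
print(is_improvement(0.91, 0.93, mode="max"))      # False
```

Inside a training loop, the save call would then sit under `if is_improvement(val_metric, best, mode): ...`.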

Examples
Saves the model only when validation loss improves.
PyTorch
# train() and validate() are placeholders for your own training and validation routines.
best_val_loss = float('inf')
for epoch in range(epochs):
    train()
    val_loss = validate()
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
Saves the model only when validation accuracy improves.
PyTorch
# Here validate() is assumed to return validation accuracy (higher is better).
best_accuracy = 0.0
for epoch in range(epochs):
    train()
    val_accuracy = validate()
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model.pth')
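To see which epochs actually trigger a save, the loss-based loop above can be traced with made-up numbers. The validation losses below are hypothetical and stand in for validate(); the torch.save call is replaced by recording the epoch number.

```python
# Hypothetical per-epoch validation losses, standing in for validate().
val_losses = [0.90, 0.70, 0.80, 0.60, 0.65]

best_val_loss = float("inf")
saved_epochs = []  # epochs where torch.save(...) would have been called

for epoch, val_loss in enumerate(val_losses, start=1):
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        saved_epochs.append(epoch)

print(saved_epochs)   # [1, 2, 4]
print(best_val_loss)  # 0.6
```

Epochs 3 and 5 are skipped because their loss is worse than the best seen so far, so best_model.pth would still hold the weights from epoch 4 at the end.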
Sample Model

This code trains a simple model and saves the best model based on validation loss. It then loads the best model and shows predictions for the first 5 samples.

PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Simple dataset
X = torch.randn(100, 10)
y = (X.sum(dim=1) > 0).long()

train_ds = TensorDataset(X, y)
train_dl = DataLoader(train_ds, batch_size=10)

# Simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)
    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

best_val_loss = float('inf')

# Dummy validation function: reuses the training data for brevity.
# In practice, evaluate on a held-out validation set.
def validate():
    model.eval()
    with torch.no_grad():
        outputs = model(X)
        loss = criterion(outputs, y)
    model.train()
    return loss.item()

for epoch in range(5):
    for xb, yb in train_dl:
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
    val_loss = validate()
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss:.4f}')
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
        print('Best model saved.')

# Load best model to check
best_model = SimpleModel()
best_model.load_state_dict(torch.load('best_model.pth'))
best_model.eval()
with torch.no_grad():
    preds = best_model(X)
    predicted_classes = preds.argmax(dim=1)
print('Predicted classes for first 5 samples:', predicted_classes[:5].tolist())
Important Notes

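Prefer saving the model weights (the state_dict) over the entire model object. A pickled model is tied to the exact class definitions and file layout used when it was saved, so it can fail to load after the code is refactored.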
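As a quick illustration of what "the weights" means here, a state_dict is an ordered mapping from parameter names to tensors. The sketch below uses a bare nn.Linear layer with the same shape as the sample model's layer; it is only meant to show the structure.

```python
import torch.nn as nn

model = nn.Linear(10, 2)
state = model.state_dict()

# Each entry maps a parameter name to its tensor.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# weight (2, 10)
# bias (2,)

# The same mapping can be loaded into a fresh instance of the same architecture.
fresh = nn.Linear(10, 2)
fresh.load_state_dict(state)
```

Because the file only holds names and tensors, loading requires constructing the model class first and then calling load_state_dict on it.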

Use validation metrics to decide when to save the model, not training metrics.
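One way to get a validation metric that is separate from the training data is to hold out part of the dataset. The sketch below assumes an 80/20 split of a synthetic dataset like the one in the sample; the split sizes, the seed, and the validate helper's signature are illustrative choices, not requirements.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic data in the same shape as the sample above.
X = torch.randn(100, 10)
y = (X.sum(dim=1) > 0).long()
dataset = TensorDataset(X, y)

# Hold out 20 samples for validation; the generator makes the split repeatable.
train_ds, val_ds = random_split(
    dataset, [80, 20], generator=torch.Generator().manual_seed(0)
)
train_dl = DataLoader(train_ds, batch_size=10, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=20)


def validate(model, loader, criterion):
    """Average loss over the validation loader, weighted by batch size."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for xb, yb in loader:
            total += criterion(model(xb), yb).item() * len(xb)
            count += len(xb)
    model.train()
    return total / count


val_loss = validate(nn.Linear(10, 2), val_dl, nn.CrossEntropyLoss())
```

The returned val_loss is the value to compare against best_val_loss when deciding whether to save.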

Remember to set the model to evaluation mode (model.eval()) when validating or testing, and switch back with model.train() before training continues.
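The mode matters because some layers behave differently in training and evaluation. A minimal sketch with nn.Dropout, chosen here only as a convenient example of such a layer:

```python
import torch
import torch.nn as nn

layer = nn.Dropout(p=0.5)
x = torch.ones(1000)

layer.train()
train_out = layer(x)  # random elements are zeroed, the rest are scaled to 2.0

layer.eval()
eval_out = layer(x)   # dropout is disabled, so the input passes through unchanged

print(torch.equal(eval_out, x))  # True
```

A validation loss computed in training mode would include this randomness, which makes epoch-to-epoch comparisons less reliable.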

Summary

Save the model only when it improves on validation data.

Use torch.save(model.state_dict(), filename) to save weights.

Load the best model later with model.load_state_dict(torch.load(filename)).
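A save-and-reload round trip can be sketched as follows. It assumes a PyTorch version where torch.load accepts weights_only (1.13 or later), uses a temporary directory instead of the working directory, and uses a bare nn.Linear in place of a full model.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(10, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "best_model.pth")
    torch.save(model.state_dict(), path)

    # map_location="cpu" lets a file saved on a GPU load on a CPU-only machine.
    # weights_only=True restricts unpickling to tensors and plain containers.
    state = torch.load(path, map_location="cpu", weights_only=True)

# Rebuild the same architecture, then load the saved weights into it.
restored = nn.Linear(10, 2)
restored.load_state_dict(state)
restored.eval()

same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```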