Saving the best model preserves the version that performed best on validation data during training, so you can use it later without retraining.
Best model saving pattern in PyTorch
    if current_val_loss < best_val_loss:
        best_val_loss = current_val_loss
        torch.save(model.state_dict(), 'best_model.pth')
Use model.state_dict() to save only the model's parameters and buffers (its weights) rather than the whole model object, which keeps the file small and portable.
Compare validation loss or accuracy to decide if the current model is better.
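The comparison described above can be wrapped in a small helper that handles both directions: lower is better for loss, higher is better for accuracy. The sketch below is one possible way to do it, not a PyTorch API. The class name BestModelSaver, its mode argument, and the sample loss values are all hypothetical and exist only for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn


class BestModelSaver:
    """Hypothetical helper: saves weights whenever the monitored metric improves."""

    def __init__(self, path, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.path = path
        self.mode = mode
        # Start from the worst possible value so the first epoch always saves.
        self.best = float("inf") if mode == "min" else float("-inf")

    def step(self, metric, model):
        """Save the model's state_dict if metric beats the best seen so far."""
        if self.mode == "min":
            improved = metric < self.best
        else:
            improved = metric > self.best
        if improved:
            self.best = metric
            torch.save(model.state_dict(), self.path)
        return improved


model = nn.Linear(10, 2)  # stand-in for a real model
with tempfile.TemporaryDirectory() as tmp_dir:
    saver = BestModelSaver(os.path.join(tmp_dir, "best_model.pth"), mode="min")
    # Made-up validation losses, one per epoch, for illustration only.
    results = [saver.step(loss, model) for loss in [0.9, 0.7, 0.8, 0.6]]
    file_exists = os.path.exists(saver.path)

print(results)  # [True, True, False, True]
```

With mode="max" the same helper would track accuracy instead: pass the validation accuracy to step() and it saves only when the value rises.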
    # Save when the validation loss reaches a new minimum
    best_val_loss = float('inf')
    for epoch in range(epochs):
        train()
        val_loss = validate()
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')

    # Save when the validation accuracy reaches a new maximum
    best_accuracy = 0.0
    for epoch in range(epochs):
        train()
        val_accuracy = validate()
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), 'best_model.pth')
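The complete example below trains a simple model and saves its weights whenever the validation loss improves. It then loads the best weights and prints predictions for the first 5 samples. For brevity, the dummy validate() function reuses the training data; in practice, evaluate on a held-out validation set.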
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader, TensorDataset

    # Simple dataset
    X = torch.randn(100, 10)
    y = (X.sum(dim=1) > 0).long()
    train_ds = TensorDataset(X, y)
    train_dl = DataLoader(train_ds, batch_size=10)

    # Simple model
    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(10, 2)

        def forward(self, x):
            return self.linear(x)

    model = SimpleModel()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    best_val_loss = float('inf')

    # Dummy validation function
    def validate():
        model.eval()
        with torch.no_grad():
            outputs = model(X)
            loss = criterion(outputs, y)
        model.train()
        return loss.item()

    for epoch in range(5):
        for xb, yb in train_dl:
            optimizer.zero_grad()
            preds = model(xb)
            loss = criterion(preds, yb)
            loss.backward()
            optimizer.step()

        val_loss = validate()
        print(f'Epoch {epoch+1}, Validation Loss: {val_loss:.4f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            print('Best model saved.')

    # Load best model to check
    best_model = SimpleModel()
    best_model.load_state_dict(torch.load('best_model.pth'))
    best_model.eval()

    with torch.no_grad():
        preds = best_model(X)
        predicted_classes = preds.argmax(dim=1)

    print('Predicted classes for first 5 samples:', predicted_classes[:5].tolist())
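Prefer saving the model weights (the state_dict) over the entire model object. A pickled model depends on the exact class definition and file layout in place when it was saved, so it can break after the code is refactored or moved.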
Use validation metrics to decide when to save the model, not training metrics.
Remember to set the model to evaluation mode (model.eval()) when validating or testing.
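To see why evaluation mode matters, consider a model with a dropout layer. The sketch below uses a hypothetical toy network (not the SimpleModel from the example above) and random input: in training mode dropout zeroes random activations, so two forward passes on the same input generally disagree, while in evaluation mode dropout is disabled and the output is repeatable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy network; the Dropout layer behaves differently in train and eval mode.
net = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(32, 2),
)
x = torch.randn(8, 10)

# Training mode: each forward pass draws a fresh dropout mask.
net.train()
train_a, train_b = net(x), net(x)

# Evaluation mode: dropout is a no-op, so repeated passes give the same result.
net.eval()
with torch.no_grad():
    eval_a, eval_b = net(x), net(x)

train_outputs_match = torch.equal(train_a, train_b)
eval_outputs_match = torch.equal(eval_a, eval_b)
print('Train-mode passes identical:', train_outputs_match)
print('Eval-mode passes identical:', eval_outputs_match)
```

A validation loss computed in training mode would carry this extra noise, which makes the "is this the best model so far?" comparison less reliable.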
Save the model only when it improves on validation data.
Use torch.save(model.state_dict(), filename) to save weights.
Load the best model later with model.load_state_dict(torch.load(filename)).
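When the weights were saved on one device and are loaded on another (for example, trained on a GPU and loaded on a CPU-only machine), torch.load accepts a map_location argument that remaps the stored tensors. The sketch below is a minimal round trip under assumed conditions: an nn.Linear layer stands in for a trained model, and a temporary directory stands in for the real checkpoint location.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for a trained model

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'best_model.pth')
    torch.save(model.state_dict(), path)

    # map_location='cpu' places every loaded tensor on the CPU,
    # regardless of the device it was saved from.
    state_dict = torch.load(path, map_location='cpu')

# The architecture must match the one that produced the saved weights.
restored = nn.Linear(10, 2)
restored.load_state_dict(state_dict)
restored.eval()

# Check that every tensor survived the round trip unchanged.
weights_match = all(
    torch.equal(saved, loaded)
    for saved, loaded in zip(model.state_dict().values(), restored.state_dict().values())
)
print('Weights match after reload:', weights_match)
```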