Saving the best model keeps the version with the best validation performance seen during training, so you can reload it later without retraining.
Best model saving pattern in PyTorch
if current_val_loss < best_val_loss:
    best_val_loss = current_val_loss
    torch.save(model.state_dict(), 'best_model.pth')
Use model.state_dict() to save only the model weights rather than the whole Python object.
Compare validation loss or accuracy to decide if the current model is better.
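The direction of the comparison depends on the metric: lower is better for loss, higher is better for accuracy. A minimal sketch of a helper that handles both cases follows; the names `is_improvement` and `mode` are illustrative assumptions, not part of PyTorch.

```python
def is_improvement(current, best, mode="min"):
    """Return True if `current` beats `best`.

    mode="min" suits losses (lower is better);
    mode="max" suits accuracy (higher is better).
    Hypothetical helper, for illustration only.
    """
    if mode == "min":
        return current < best
    if mode == "max":
        return current > best
    raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")


# Starting values that any real metric will beat
best_loss = float("inf")
best_acc = float("-inf")

loss_improved = is_improvement(0.42, best_loss, mode="min")
acc_improved = is_improvement(0.91, best_acc, mode="max")
print(loss_improved, acc_improved)
```

A strict comparison means a tie does not trigger a new save, which matches the `<` and `>` checks used throughout this page.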
best_val_loss = float('inf')
for epoch in range(epochs):
    train()
    val_loss = validate()
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')

best_accuracy = 0.0
for epoch in range(epochs):
    train()
    val_accuracy = validate()
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model.pth')
This code trains a simple model and saves the best model based on validation loss. It then loads the best model and shows predictions for the first 5 samples.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Simple dataset
X = torch.randn(100, 10)
y = (X.sum(dim=1) > 0).long()
train_ds = TensorDataset(X, y)
train_dl = DataLoader(train_ds, batch_size=10)

# Simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
best_val_loss = float('inf')

# Dummy validation function (reuses the training data; use a held-out set in practice)
def validate():
    model.eval()
    with torch.no_grad():
        outputs = model(X)
        loss = criterion(outputs, y)
    model.train()
    return loss.item()

for epoch in range(5):
    for xb, yb in train_dl:
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
    val_loss = validate()
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss:.4f}')
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
        print('Best model saved.')

# Load best model to check
best_model = SimpleModel()
best_model.load_state_dict(torch.load('best_model.pth'))
best_model.eval()
with torch.no_grad():
    preds = best_model(X)
    predicted_classes = preds.argmax(dim=1)
print('Predicted classes for first 5 samples:', predicted_classes[:5].tolist())
Always save the model weights, not the entire model object, for better compatibility.
Use validation metrics to decide when to save the model, not training metrics.
Remember to set the model to evaluation mode (model.eval()) when validating or testing.
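The note about evaluation mode matters most for layers such as dropout, which behave differently in training. The sketch below uses a hypothetical toy model (not one from this page) to show that outputs become repeatable once `model.eval()` is called, and that the model should be switched back with `model.train()` before the next epoch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model with dropout, used only to show the effect of eval()
model = nn.Sequential(nn.Linear(10, 10), nn.Dropout(p=0.5), nn.Linear(10, 2))
x = torch.randn(4, 10)

model.eval()  # dropout is disabled in evaluation mode
with torch.no_grad():
    out_a = model(x)
    out_b = model(x)
eval_outputs_match = torch.equal(out_a, out_b)

model.train()  # switch back before the next training epoch
print(eval_outputs_match, model.training)
```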
Save the model only when it improves on validation data.
Use torch.save(model.state_dict(), filename) to save weights.
Load the best model later with model.load_state_dict(torch.load(filename)).
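The save and load steps can be sketched as a round trip. This is a minimal example under a few assumptions: the file goes to a temporary directory so nothing is left behind, and `map_location="cpu"` is passed so that weights saved on a GPU would still load on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(10, 2)

# Write to a temporary directory so the example leaves no files behind
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "best_model.pth")
    torch.save(model.state_dict(), path)

    # map_location="cpu" places the loaded tensors on the CPU
    state = torch.load(path, map_location="cpu")

restored = nn.Linear(10, 2)  # must match the saved architecture
restored.load_state_dict(state)
restored.eval()

# Compare every saved tensor with its restored counterpart
weights_match = all(
    torch.equal(p, q)
    for p, q in zip(model.state_dict().values(), restored.state_dict().values())
)
print(weights_match)
```

The new model object has to be built with the same architecture before `load_state_dict` is called, because the file holds tensors only, not the class definition.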
Practice
Solution
Step 1: Understand model saving timing
Saving the model only when validation improves ensures you keep the best version and avoids unnecessary saves.
Step 2: Compare other options
Saving every batch wastes space; saving only at the start or when the loss increases does not keep the best model.
Final Answer:
Save the model only when it improves on validation data. -> Option B
Quick Check:
Save best validation model [OK]
- Saving model too frequently wastes storage
- Saving only at start misses improvements
- Saving on training loss increase is wrong
Solution
Step 1: Identify correct saving method
PyTorch saves weights using torch.save(model.state_dict(), filename).
Step 2: Check other options
Saving the whole model (torch.save(model, 'model.pth')) is possible but less flexible; options C and D are invalid syntax.
Final Answer:
torch.save(model.state_dict(), 'model.pth') -> Option A
Quick Check:
Save weights with state_dict() = A [OK]
- Trying to save model directly without state_dict
- Using non-existent save methods on model
- Confusing saving weights vs full model
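To make the difference between weights and the full model concrete, the sketch below prints what a `state_dict()` actually holds: a mapping from parameter names to tensors. The layer sizes are arbitrary and chosen only for illustration.

```python
import torch.nn as nn

model = nn.Linear(3, 2)
state = model.state_dict()

# Each key names a parameter; each value is the tensor holding its data
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
```

Because the mapping contains only named tensors, it can be loaded into any model object whose parameter names and shapes match.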
import torch
import torch.nn as nn
model = nn.Linear(2, 1)
torch.save(model.state_dict(), 'best.pth')
new_model = nn.Linear(2, 1)
new_model.load_state_dict(torch.load('best.pth'))
print(new_model.weight.shape)
Solution
Step 1: Understand model architecture
nn.Linear(2, 1) creates weights of shape [1, 2] (output features, input features).
Step 2: Loading weights into new model
Loading saved weights into an identical model keeps the weight shape the same.
Final Answer:
torch.Size([1, 2]) -> Option A
Quick Check:
Linear(2, 1) weight shape = [1, 2] [OK]
- Confusing input/output dimensions order
- Expecting error when loading identical model
- Misreading weight shape as (2,1)
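Loading succeeds here only because both models share one architecture. As a hedged illustration, the sketch below loads the same weights into a matching layer and into a deliberately mismatched one (a hypothetical mistake); the mismatched load raises a `RuntimeError` about the differing shapes.

```python
import torch.nn as nn

saved_state = nn.Linear(2, 1).state_dict()

# Same architecture: loads cleanly
same = nn.Linear(2, 1)
same.load_state_dict(saved_state)

# Different architecture (hypothetical mistake): weight shapes disagree
different = nn.Linear(3, 1)
try:
    different.load_state_dict(saved_state)
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True

print(tuple(same.weight.shape), mismatch_detected)
```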
if val_loss < best_loss:
    best_loss = val_loss
    torch.save(model, 'best_model.pth')
Solution
Step 1: Analyze saving method
Saving the entire model works but is less flexible: the file is tied to the model's class definition, so loading may fail if the code is moved or changed, or under a different PyTorch version.
Step 2: Compare with best practice
Best practice is saving model.state_dict() for portability and smaller files.
Final Answer:
It saves the entire model, which is less flexible than saving state_dict. -> Option D
Quick Check:
Save state_dict() preferred over full model [OK]
- Saving full model without state_dict
- Ignoring portability issues
- Assuming full model save is always best
best_acc = 0.0
for epoch in range(epochs):
    train()
    val_acc = validate()
    # Save best model here
    ???
Solution
Step 1: Identify saving condition
We save the model only if validation accuracy improves (val_acc > best_acc).
Step 2: Update best accuracy and save weights
Update best_acc and save model.state_dict() to keep the best weights.
Final Answer:
if val_acc > best_acc:
    best_acc = val_acc
    torch.save(model.state_dict(), 'best_model.pth')
-> Option C
Quick Check:
Save on val_acc improvement [OK]
- Saving when accuracy decreases
- Saving every epoch wastes space
- Not updating best accuracy variable
