Best model saving pattern in PyTorch

Introduction

Saving the best model keeps the version that performs best on validation data during training. This way, you can use that model later without retraining.

When you train a model over many epochs and want to keep the best version.
When you want to avoid losing the best model if training stops unexpectedly.
When you want to compare different saved models later.
When you want to load the best model for testing or deployment.
When you want to save disk space by only keeping the best model, not all checkpoints.
Syntax
PyTorch
if current_val_loss < best_val_loss:
    best_val_loss = current_val_loss
    torch.save(model.state_dict(), 'best_model.pth')

Use model.state_dict() to save only the model's learned state (its parameters plus buffers such as BatchNorm running statistics), not the Python object itself.

Compare a validation metric to decide whether the current model is better: lower is better for loss, higher is better for accuracy.
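The comparison above can be wrapped in a small reusable object that handles both directions. This is a sketch only: BestModelCheckpoint, its mode argument and its step() method are hypothetical names chosen for illustration, not part of PyTorch.

```python
import torch
import torch.nn as nn


class BestModelCheckpoint:
    """Hypothetical helper: saves a model's state_dict whenever a metric improves.

    mode="min": lower is better (e.g. validation loss).
    mode="max": higher is better (e.g. validation accuracy).
    """

    def __init__(self, path, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.path = path
        self.mode = mode
        # Start from the worst possible value so the first evaluation always saves.
        self.best = float("inf") if mode == "min" else float("-inf")

    def is_better(self, value):
        return value < self.best if self.mode == "min" else value > self.best

    def step(self, model, value):
        """Save weights if `value` beats the best so far; return True if saved."""
        if self.is_better(value):
            self.best = value
            torch.save(model.state_dict(), self.path)
            return True
        return False


model = nn.Linear(4, 2)
checkpoint = BestModelCheckpoint("best_model.pth", mode="min")
for val_loss in [0.9, 0.7, 0.8, 0.5]:
    saved = checkpoint.step(model, val_loss)
    print(f"val_loss={val_loss} saved={saved}")
print("best:", checkpoint.best)  # best: 0.5
```

Starting from float("-inf") in "max" mode, rather than 0.0, guarantees a first save even if accuracy begins at exactly 0.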

Examples
Saves the model only when validation loss improves.
PyTorch
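# train(), validate() and epochs are placeholders for your own training step, validation function and epoch count
best_val_loss = float('inf')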
for epoch in range(epochs):
    train()
    val_loss = validate()
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
Saves the model only when validation accuracy improves.
PyTorch
# Here validate() is assumed to return validation accuracy (higher is better)
best_accuracy = 0.0
for epoch in range(epochs):
    train()
    val_accuracy = validate()
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model.pth')
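The accuracy example assumes validate() returns a fraction between 0 and 1. One way to write such a function is sketched below; validate_accuracy(model, loader) is a hypothetical helper, and the synthetic validation data is assumed for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def validate_accuracy(model, loader):
    """Hypothetical helper: fraction of correctly classified validation samples."""
    model.eval()  # evaluation mode for dropout / batch norm layers
    correct, total = 0, 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for xb, yb in loader:
            preds = model(xb).argmax(dim=1)  # predicted class per sample
            correct += (preds == yb).sum().item()
            total += yb.numel()
    model.train()  # restore training mode for the next epoch
    return correct / total


# Synthetic validation data (assumed for illustration)
torch.manual_seed(0)
X_val = torch.randn(50, 10)
y_val = (X_val.sum(dim=1) > 0).long()
val_dl = DataLoader(TensorDataset(X_val, y_val), batch_size=10)

model = nn.Linear(10, 2)
best_accuracy = 0.0
val_accuracy = validate_accuracy(model, val_dl)
if val_accuracy > best_accuracy:
    best_accuracy = val_accuracy
    torch.save(model.state_dict(), "best_model.pth")
print(f"Validation accuracy: {val_accuracy:.2f}")
```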
Sample Model

This code trains a simple model and saves the best model based on validation loss. It then loads the best model and shows predictions for the first 5 samples.

PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

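torch.manual_seed(0)  # reproducible data and initial weights

# Simple synthetic dataset: label is 1 when the features sum to a positive value
X = torch.randn(100, 10)
y = (X.sum(dim=1) > 0).long()

# Separate held-out validation set, so the save decision is not based on training data
X_val = torch.randn(40, 10)
y_val = (X_val.sum(dim=1) > 0).long()

train_ds = TensorDataset(X, y)
train_dl = DataLoader(train_ds, batch_size=10, shuffle=True)

# Simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

best_val_loss = float('inf')

# Validation function: computes the loss on the held-out validation set
def validate():
    model.eval()  # evaluation mode (matters for dropout / batch norm layers)
    with torch.no_grad():
        outputs = model(X_val)
        loss = criterion(outputs, y_val)
    model.train()  # switch back to training mode
    return loss.item()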

for epoch in range(5):
    for xb, yb in train_dl:
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
    val_loss = validate()
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss:.4f}')
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
        print('Best model saved.')

# Load best model to check
best_model = SimpleModel()
best_model.load_state_dict(torch.load('best_model.pth', weights_only=True))
best_model.eval()
with torch.no_grad():
    preds = best_model(X)
    predicted_classes = preds.argmax(dim=1)
print('Predicted classes for first 5 samples:', predicted_classes[:5].tolist())
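A variant of this pattern keeps the best weights in memory instead of on disk, for example to restore them at the end of training. model.state_dict() returns tensors that share storage with the live parameters, so a plain reference keeps changing as the model trains; copy.deepcopy freezes a snapshot. A minimal sketch with a toy nn.Linear model (assumed for illustration):

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Pitfall: state_dict() tensors share storage with the live parameters.
aliased = model.state_dict()
# Correct: deep-copy the tensors so the snapshot stays fixed.
best_state = copy.deepcopy(model.state_dict())

# One training step updates the parameters in place.
loss = model(torch.randn(8, 3)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(torch.equal(aliased["weight"], model.weight))     # True: followed the update
print(torch.equal(best_state["weight"], model.weight))  # False: kept pre-step values

# Restore the best snapshot (copies values back into the parameters).
model.load_state_dict(best_state)
```

Writing to a file with torch.save avoids the aliasing issue, since the file holds a copy of the tensors at save time.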
Important Notes

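Save the model's state_dict rather than the whole model object. torch.save(model, ...) pickles the entire object, so loading it later requires the same class definition to be importable from the same module path, and renaming or moving that code can break loading.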
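To see what a state_dict actually holds, the sketch below uses a small model with a BatchNorm1d layer (an assumption for illustration, not part of the sample). The saved dictionary contains parameters plus buffers such as running_mean, and loading only requires rebuilding the same architecture in code.

```python
import torch
import torch.nn as nn


def build_model():
    # The architecture must be recreated in code before weights can be loaded.
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))


model = build_model()
state = model.state_dict()
for name, tensor in state.items():
    # Keys such as "1.running_mean" are buffers, not trainable parameters.
    print(name, tuple(tensor.shape))

torch.save(state, "weights_only.pth")

restored = build_model()
restored.load_state_dict(torch.load("weights_only.pth", weights_only=True))
```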

Use validation metrics to decide when to save the model, not training metrics.

Call model.eval() before validating or testing so that layers such as dropout and batch normalization switch to inference behavior, and call model.train() again before resuming training.
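A quick way to see what model.eval() changes, assuming a toy model with a Dropout layer: in training mode dropout randomly zeroes activations, so two forward passes on the same input usually differ, while in evaluation mode dropout is disabled and the output is deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 10), nn.Dropout(p=0.5), nn.Linear(10, 2))
x = torch.randn(4, 10)

model.train()  # dropout active: a new random mask on every forward pass
with torch.no_grad():
    train_out_1 = model(x)
    train_out_2 = model(x)

model.eval()  # dropout disabled: forward passes are deterministic
with torch.no_grad():
    eval_out_1 = model(x)
    eval_out_2 = model(x)

print(torch.equal(train_out_1, train_out_2))  # typically False
print(torch.equal(eval_out_1, eval_out_2))    # True
```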

Summary

Save the model only when it improves on validation data.

Use torch.save(model.state_dict(), filename) to save weights.

Load the best model later by creating the same architecture and calling model.load_state_dict(torch.load(filename)).
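The loading step, sketched end to end. It redefines the SimpleModel class from the sample so it runs on its own, and adds map_location, a torch.load argument that remaps saved tensors onto the given device (useful when weights saved on a GPU are loaded on a CPU-only machine). The weights_only=True flag restricts unpickling to tensors and plain containers; it is available in PyTorch 1.13 and later.

```python
import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)


# Stand-in for the file written during training.
torch.save(SimpleModel().state_dict(), "best_model.pth")

# Later: rebuild the architecture, load the weights, switch to eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleModel()
state = torch.load("best_model.pth", map_location=device, weights_only=True)
model.load_state_dict(state)
model.to(device)
model.eval()

with torch.no_grad():
    preds = model(torch.randn(3, 10, device=device)).argmax(dim=1)
print(preds.tolist())
```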

Practice

1. What is the best practice for saving a PyTorch model during training?
easy
A. Save the model only at the start of training.
B. Save the model only when it improves on validation data.
C. Save the model after every training batch.
D. Save the model only if the training loss increases.

Solution

  1. Step 1: Understand model saving timing

    Saving the model only when validation improves ensures you keep the best version, avoiding unnecessary saves.
  2. Step 2: Compare other options

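    Saving after every batch adds needless disk writes and, with a single file, overwrites the best weights with worse ones; saving only at the start or when the training loss increases never captures the best model.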
  3. Final Answer:

    Save the model only when it improves on validation data. -> Option B
  4. Quick Check:

    Save best validation model = B [OK]
Hint: Save model only on validation improvement to keep best [OK]
Common Mistakes:
  • Saving model too frequently wastes storage
  • Saving only at start misses improvements
  • Saving on training loss increase is wrong
2. Which of the following is the correct PyTorch code to save only the model weights?
easy
A. torch.save(model.state_dict(), 'model.pth')
B. torch.save(model, 'model.pth')
C. model.save('model.pth')
D. model.state_dict().save('model.pth')

Solution

  1. Step 1: Identify correct saving method

    PyTorch saves weights using torch.save(model.state_dict(), filename).
  2. Step 2: Check other options

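    Saving the whole model (torch.save(model, 'model.pth')) also works but is less flexible. Options C and D call methods that do not exist (nn.Module has no save() method, and neither does the dictionary returned by state_dict()), so they raise AttributeError.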
  3. Final Answer:

    torch.save(model.state_dict(), 'model.pth') -> Option A
  4. Quick Check:

    Save weights with state_dict() = A [OK]
Hint: Use torch.save(model.state_dict(), filename) to save weights [OK]
Common Mistakes:
  • Trying to save model directly without state_dict
  • Using non-existent save methods on model
  • Confusing saving weights vs full model
3. Given this code snippet, what will be printed?
import torch
import torch.nn as nn

model = nn.Linear(2, 1)
torch.save(model.state_dict(), 'best.pth')
new_model = nn.Linear(2, 1)
new_model.load_state_dict(torch.load('best.pth'))
print(new_model.weight.shape)
medium
A. torch.Size([1, 2])
B. torch.Size([2, 1])
C. torch.Size([1, 1])
D. Error: shape mismatch

Solution

  1. Step 1: Understand model architecture

    nn.Linear(2,1) creates weights of shape [1, 2] (output features, input features).
  2. Step 2: Loading weights into new model

    Loading the saved weights into a model with the same architecture leaves the weight shape unchanged.
  3. Final Answer:

    torch.Size([1, 2]) -> Option A
  4. Quick Check:

    Linear(2,1) weight shape = [1, 2] [OK]
Hint: Linear layer weights shape = (out_features, in_features) [OK]
Common Mistakes:
  • Confusing input/output dimensions order
  • Expecting error when loading identical model
  • Misreading weight shape as (2,1)
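To check the shape rule from this question yourself, the snippet below inspects nn.Linear(2, 1) and recomputes its output by hand, since the layer computes x @ weight.T + bias:

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 1)  # in_features=2, out_features=1
print(layer.weight.shape)  # torch.Size([1, 2]) -> (out_features, in_features)
print(layer.bias.shape)    # torch.Size([1])

x = torch.randn(5, 2)  # batch of 5 inputs with 2 features each
manual = x @ layer.weight.T + layer.bias  # the computation nn.Linear performs
print(torch.allclose(layer(x), manual, atol=1e-6))
```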
4. What is wrong with this code snippet for saving the best model?
if val_loss < best_loss:
    best_loss = val_loss
    torch.save(model, 'best_model.pth')
medium
A. There is no condition to check validation loss.
B. It should save model.state_dict() instead of model.
C. It does not update best_loss correctly.
D. It saves the entire model, which is less flexible than saving state_dict.

Solution

  1. Step 1: Analyze saving method

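    Saving the entire model works, but torch.save(model, ...) pickles the whole object, so loading it later requires the original class definition to be importable from the same module path; renaming or moving that code can break loading.
  2. Step 2: Compare with best practice

    Best practice is to save model.state_dict(), which holds only tensors and loads into any instance of the same architecture.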
  3. Final Answer:

    It saves the entire model, which is less flexible than saving state_dict. -> Option D
  4. Quick Check:

    Save state_dict() preferred over full model [OK]
Hint: Save state_dict() for flexibility, not full model [OK]
Common Mistakes:
  • Saving full model without state_dict
  • Ignoring portability issues
  • Assuming full model save is always best
5. You want to save the best model during training based on validation accuracy. Which code snippet correctly implements this pattern?
best_acc = 0.0
for epoch in range(epochs):
    train()
    val_acc = validate()
    # Save best model here
    ???
hard
A. if val_acc < best_acc: best_acc = val_acc; torch.save(model.state_dict(), 'best_model.pth')
B. if val_acc == best_acc: torch.save(model.state_dict(), 'best_model.pth')
C. if val_acc > best_acc: best_acc = val_acc; torch.save(model.state_dict(), 'best_model.pth')
D. torch.save(model.state_dict(), 'best_model.pth')  # save every epoch

Solution

  1. Step 1: Identify saving condition

    We save model only if validation accuracy improves (val_acc > best_acc).
  2. Step 2: Update best accuracy and save weights

    Update best_acc and save model.state_dict() to keep best weights.
  3. Final Answer:

    if val_acc > best_acc: best_acc = val_acc; torch.save(model.state_dict(), 'best_model.pth') -> Option C
  4. Quick Check:

    Save on val_acc improvement = C [OK]
Hint: Save model only if validation accuracy improves [OK]
Common Mistakes:
  • Saving when accuracy decreases
  • Saving every epoch wastes space
  • Not updating best accuracy variable