
How to Use Early Stopping in PyTorch for Better Training

In PyTorch, early stopping is implemented by monitoring a validation metric during training and stopping when it stops improving. You can create a custom EarlyStopping class that saves the best model and stops training after a set patience period without improvement.

Syntax

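Early stopping is not built into core PyTorch, but you can implement it with a small class that tracks a validation metric. The class below tracks validation loss. To track a metric that should increase instead, such as accuracy, you have to reverse the comparison. Its key parts are: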

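  • patience: number of consecutive epochs without improvement to allow before stopping
  • delta: minimum decrease in validation loss that counts as an improvement
  • best_score: best score seen so far (the negated validation loss)
  • counter: number of consecutive epochs without improvement
  • early_stop: flag set to True once counter reaches patience
  • save_checkpoint(): saves the model's state_dict to checkpoint.pt whenever the loss improves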
python
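import torch

class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience    # epochs to wait after the last improvement
        self.delta = delta          # minimum loss decrease that counts as improvement
        self.best_score = None      # best (negated) validation loss so far
        self.counter = 0            # consecutive epochs without improvement
        self.early_stop = False     # becomes True when patience runs out

    def __call__(self, val_loss, model):
        score = -val_loss  # negate so that a higher score is better
        if self.best_score is None:
            # First call: treat it as the best result so far
            self.best_score = score
            self.save_checkpoint(model)
        elif score < self.best_score + self.delta:
            # Not better by at least delta: count this epoch against patience
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: remember it, save the weights, reset the counter
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0

    def save_checkpoint(self, model):
        # Save only the parameters (state_dict), not the whole model object
        torch.save(model.state_dict(), 'checkpoint.pt')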
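The class above negates `val_loss`, so as written it only handles metrics that should go down. Below is a minimal sketch of a variant that also works for accuracy and keeps the best weights in memory instead of writing `checkpoint.pt`. The `mode`, `best_metric`, `best_state` and `restore_best` names are hypothetical choices for this sketch, not a PyTorch API. The sketch assumes only that the model has `state_dict()` and `load_state_dict()`, which every `torch.nn.Module` does. The demo uses a tiny stand-in model so it runs without PyTorch.

```python
import copy


class EarlyStopping:
    """Sketch: early stopping for a metric that should decrease (mode="min",
    e.g. loss) or increase (mode="max", e.g. accuracy). Keeps the best
    weights in memory. Hypothetical API, not part of PyTorch."""

    def __init__(self, patience=5, delta=0.0, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.delta = delta
        self.mode = mode
        self.best_metric = None
        self.best_state = None
        self.counter = 0
        self.early_stop = False

    def _is_improvement(self, metric):
        if self.best_metric is None:
            return True
        if self.mode == "min":
            return metric <= self.best_metric - self.delta
        return metric >= self.best_metric + self.delta

    def __call__(self, metric, model):
        if self._is_improvement(metric):
            self.best_metric = metric
            # deepcopy: state_dict() tensors share storage with the live
            # parameters, so later optimizer steps would overwrite a plain reference
            self.best_state = copy.deepcopy(model.state_dict())
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def restore_best(self, model):
        model.load_state_dict(self.best_state)


# Stand-in for an nn.Module so the demo runs without PyTorch
class TinyModel:
    def __init__(self):
        self.weights = [0.0]

    def state_dict(self):
        return {"weights": self.weights}  # live reference, like PyTorch's tensors

    def load_state_dict(self, state):
        self.weights = list(state["weights"])


model = TinyModel()
stopper = EarlyStopping(patience=2, mode="max")
for epoch, val_acc in enumerate([0.70, 0.75, 0.74, 0.73], start=1):
    model.weights[0] = float(epoch)  # pretend an optimizer step changed the weights
    stopper(val_acc, model)
    if stopper.early_stop:
        break

stopper.restore_best(model)
print(stopper.best_metric, model.weights)  # 0.75 [2.0]
```

In the training loop, you would pass validation accuracy with `mode="max"` and call `restore_best(model)` after the loop in place of `torch.load`. The in-memory copy sits on the same device as the model, so it costs one extra copy of the weights. For large models, saving to disk as `save_checkpoint` does may be the better choice.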

Example

This example uses the EarlyStopping class while training a simple linear model on random dummy data. Training stops if the validation loss does not improve by at least delta=0.001 for 3 consecutive epochs. The best weights are then reloaded from checkpoint.pt.

python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)
    def forward(self, x):
        return self.fc(x)

# EarlyStopping class from the previous section (default patience set to 3)
class EarlyStopping:
    def __init__(self, patience=3, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0

    def save_checkpoint(self, model):
        torch.save(model.state_dict(), 'checkpoint.pt')

# Dummy data
x_train = torch.randn(100, 10)
y_train = torch.randn(100, 1)
x_val = torch.randn(20, 10)
y_val = torch.randn(20, 1)

model = Net()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
early_stopping = EarlyStopping(patience=3, delta=0.001)

for epoch in range(50):
    model.train()
    optimizer.zero_grad()
    output = model(x_train)
    loss = criterion(output, y_train)
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_output = model(x_val)
        val_loss = criterion(val_output, y_val).item()

    print(f'Epoch {epoch+1}, Val Loss: {val_loss:.4f}')

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print('Early stopping triggered')
        break

# Load best model
model.load_state_dict(torch.load('checkpoint.pt'))
Output (one sample run; exact values differ between runs because the data comes from torch.randn)
Epoch 1, Val Loss: 1.1234
Epoch 2, Val Loss: 1.0456
Epoch 3, Val Loss: 0.9876
Epoch 4, Val Loss: 0.9801
Epoch 5, Val Loss: 0.9799
Epoch 6, Val Loss: 0.9800
Epoch 7, Val Loss: 0.9802
Early stopping triggered
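You can check why this run stopped at epoch 7 by replaying the class's rule on the printed losses. The sketch below uses `trace_early_stopping`, a hypothetical helper rather than part of PyTorch. It applies the same condition as `EarlyStopping.__call__`: an epoch is an improvement only if the loss is at least `delta` below the best loss so far.

```python
def trace_early_stopping(val_losses, patience=3, delta=0.0):
    """Replay the EarlyStopping rule on a list of validation losses.

    Returns (stop_epoch, best_epoch) as 1-based epochs; stop_epoch is None
    if patience never runs out.
    """
    best_loss, best_epoch, counter = None, None, 0
    for epoch, loss in enumerate(val_losses, start=1):
        # Equivalent to the class's check on score = -val_loss
        if best_loss is None or loss <= best_loss - delta:
            best_loss, best_epoch, counter = loss, epoch, 0
        else:
            counter += 1
            if counter >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Validation losses from the sample output above
losses = [1.1234, 1.0456, 0.9876, 0.9801, 0.9799, 0.9800, 0.9802]
print(trace_early_stopping(losses, patience=3, delta=0.001))  # (7, 4)
```

Epoch 5 (0.9799) is lower than epoch 4 (0.9801) but does not count as an improvement, because the drop is smaller than delta. The counter therefore increases at epochs 5, 6 and 7, and checkpoint.pt holds the weights from epoch 4.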

Common Pitfalls

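  • Not resetting the counter: If the patience counter is not reset when the metric improves, bad epochs accumulate across improvements and training stops prematurely.
  • Monitoring the wrong metric: Monitor validation loss or accuracy, not training loss. Training loss usually keeps falling even after the model starts to overfit, so it rarely triggers a stop.
  • Not saving the best model: Without a checkpoint of the best weights, you are left with the weights from the final epoch, which is patience epochs past the best one.
  • Patience too small: With a very low patience, normal epoch-to-epoch noise in the validation metric can stop training before the model has converged.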
python
class EarlyStoppingWrong:
    def __init__(self, patience=3):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            # Bug: counter is not reset here, so bad epochs keep adding up across improvements
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

# Correct way resets counter when improvement occurs
class EarlyStoppingRight:
    def __init__(self, patience=3):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.counter = 0  # Reset counter here
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
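To see the difference concretely, here is a sketch that replays both update rules on a loss curve that reaches a new best every other epoch. The function `stop_epoch` and its `reset_on_improve` switch are hypothetical: the switch selects between the EarlyStoppingWrong and EarlyStoppingRight behavior. The loss values are made up for illustration.

```python
def stop_epoch(val_losses, patience, reset_on_improve):
    """Replay the EarlyStoppingWrong / EarlyStoppingRight update rule.

    Returns the 1-based epoch at which early stopping fires, or None.
    """
    best_loss, counter = None, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if best_loss is None or loss < best_loss:  # same as score > best_score
            best_loss = loss
            if reset_on_improve:
                counter = 0  # the line EarlyStoppingWrong is missing
        else:
            counter += 1
            if counter >= patience:
                return epoch
    return None


# Made-up curve: a new best every other epoch, never 3 bad epochs in a row
losses = [1.0, 1.1, 0.9, 1.0, 0.8, 0.9, 0.7, 0.8]
print(stop_epoch(losses, patience=3, reset_on_improve=False))  # 6
print(stop_epoch(losses, patience=3, reset_on_improve=True))   # None
```

The model is still improving here, yet the version without the reset stops at epoch 6. It has counted three bad epochs that were never consecutive.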

Quick Reference

Early Stopping Tips:

  • Monitor validation loss or accuracy, not training loss.
  • Set patience high enough to tolerate a few noisy epochs without improvement.
  • Save the best model checkpoint to restore after stopping.
  • Use a small delta to ignore tiny fluctuations.
  • After training stops, load the saved checkpoint so you keep the best weights rather than the last ones.
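The first tip can be illustrated with two made-up loss curves that have the typical overfitting shape: training loss keeps falling while validation loss bottoms out and rises. The sketch applies the same rule as the EarlyStopping class to each curve. `early_stop_epoch` is a hypothetical helper, and the numbers are illustrative, not measured.

```python
def early_stop_epoch(losses, patience=3, delta=0.0):
    """Apply the EarlyStopping rule to a loss history.

    Returns (stop_epoch, best_epoch) as 1-based epochs; stop_epoch is None
    if the rule never fires.
    """
    best_loss, best_epoch, counter = None, None, 0
    for epoch, loss in enumerate(losses, start=1):
        if best_loss is None or loss <= best_loss - delta:
            best_loss, best_epoch, counter = loss, epoch, 0
        else:
            counter += 1
            if counter >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Made-up curves: the model starts to overfit after epoch 4
train_losses = [1.00, 0.80, 0.65, 0.55, 0.47, 0.41, 0.36, 0.32]
val_losses = [1.05, 0.90, 0.82, 0.80, 0.81, 0.83, 0.86, 0.90]

print(early_stop_epoch(train_losses))  # (None, 8): training loss never triggers a stop
print(early_stop_epoch(val_losses))    # (7, 4): stops 3 epochs after the validation minimum
```

Monitoring training loss here would let training run to the end and keep the most overfit weights. Monitoring validation loss stops at epoch 7 and points to epoch 4 as the checkpoint to restore.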

Key Takeaways

Implement early stopping by tracking a validation metric and stopping once it has not improved for patience consecutive epochs.
Always save the best model checkpoint during training so you can restore it after stopping.
Reset the patience counter whenever the validation metric improves to avoid premature stopping.
Monitor validation loss or accuracy, not training loss, for early stopping decisions.
Choose patience and delta values carefully to balance training time and model quality.
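To make the patience trade-off concrete, here is a sketch that applies the same rule to one made-up noisy validation curve with several patience values. It reports where training stops and which epoch would be restored. `stop_and_best` is a hypothetical helper, and the curve is illustrative only.

```python
def stop_and_best(val_losses, patience, delta=0.0):
    """Apply the EarlyStopping rule; return (stop_epoch, best_epoch, best_loss).

    Epochs are 1-based; stop_epoch is None if patience never runs out.
    """
    best_loss, best_epoch, counter = None, None, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if best_loss is None or loss <= best_loss - delta:
            best_loss, best_epoch, counter = loss, epoch, 0
        else:
            counter += 1
            if counter >= patience:
                return epoch, best_epoch, best_loss
    return None, best_epoch, best_loss


# Made-up noisy validation curve with small late improvements
val_losses = [1.00, 0.90, 0.85, 0.86, 0.84, 0.85, 0.86, 0.83, 0.87, 0.88, 0.89, 0.90]

for patience in (1, 2, 3, 4):
    stop, best_epoch, best_loss = stop_and_best(val_losses, patience)
    print(f"patience={patience}: stopped at epoch {stop}, "
          f"best epoch {best_epoch} (loss {best_loss:.2f})")
```

On this curve, patience=1 stops at epoch 4 with a best loss of 0.85, and patience=2 stops at epoch 7 with 0.84. Patience=3 reaches 0.83 but runs until epoch 11, and patience=4 spends one more epoch for no further gain.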