How to Use Early Stopping in PyTorch for Better Training
In PyTorch, early stopping is implemented by monitoring a validation metric during training and stopping when it stops improving. You can create a custom EarlyStopping class that saves the best model and stops training after a set patience period without improvement.
Syntax
Early stopping in PyTorch is not built-in but can be implemented by creating a class that tracks validation loss or accuracy. The key parts are:
- patience: how many epochs to wait for improvement before stopping
- delta: minimum change to qualify as improvement
- best_score: best metric value seen so far
- counter: counts epochs without improvement
- save_checkpoint(): saves the model when improvement occurs
```python
import torch

class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience      # epochs to wait without improvement
        self.delta = delta            # minimum change that counts as improvement
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        score = -val_loss  # negate loss so that higher score is always better
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model)
        elif score < self.best_score + self.delta:
            # No sufficient improvement this epoch
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0  # improvement: reset the patience counter

    def save_checkpoint(self, model):
        torch.save(model.state_dict(), 'checkpoint.pt')
```
Example
This example shows how to use the EarlyStopping class while training a simple neural network on dummy data. Training stops if validation loss does not improve by at least delta=0.001 for 3 consecutive epochs.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# EarlyStopping class from previous section
class EarlyStopping:
    def __init__(self, patience=3, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0

    def save_checkpoint(self, model):
        torch.save(model.state_dict(), 'checkpoint.pt')

# Dummy data
x_train = torch.randn(100, 10)
y_train = torch.randn(100, 1)
x_val = torch.randn(20, 10)
y_val = torch.randn(20, 1)

model = Net()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
early_stopping = EarlyStopping(patience=3, delta=0.001)

for epoch in range(50):
    model.train()
    optimizer.zero_grad()
    output = model(x_train)
    loss = criterion(output, y_train)
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_output = model(x_val)
        val_loss = criterion(val_output, y_val).item()
    print(f'Epoch {epoch+1}, Val Loss: {val_loss:.4f}')

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print('Early stopping triggered')
        break

# Load best model
model.load_state_dict(torch.load('checkpoint.pt'))
```
Output
Exact values vary between runs because the data is random; a typical run looks like:
Epoch 1, Val Loss: 1.1234
Epoch 2, Val Loss: 1.0456
Epoch 3, Val Loss: 0.9876
Epoch 4, Val Loss: 0.9801
Epoch 5, Val Loss: 0.9799
Epoch 6, Val Loss: 0.9800
Epoch 7, Val Loss: 0.9802
Early stopping triggered
Common Pitfalls
- Not resetting the counter: Forgetting to reset the patience counter when improvement occurs causes premature stopping.
- Monitoring wrong metric: Early stopping should monitor validation loss or accuracy, not training loss.
- Saving model incorrectly: Not saving the best model checkpoint means you might keep a worse model after stopping.
- Patience set too low: Training can stop before the model has learned well.
```python
class EarlyStoppingWrong:
    def __init__(self, patience=3):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            # Missing reset of counter here causes issues
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

# Correct way resets counter when improvement occurs
class EarlyStoppingRight:
    def __init__(self, patience=3):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.counter = 0  # Reset counter here
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
```
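The pitfall about monitoring the wrong metric raises a related point: the metric is sometimes accuracy (higher is better) rather than loss (lower is better). Below is a minimal sketch of one way to support both. The class name EarlyStoppingFlexible and the `mode` parameter are assumptions for illustration, not part of the class shown above; checkpoint saving is omitted to keep the sketch short.

```python
class EarlyStoppingFlexible:
    """Sketch: monitor either a loss ('min' mode) or an
    accuracy-style metric ('max' mode). The 'mode' parameter
    is a hypothetical extension of the class above."""
    def __init__(self, patience=3, delta=0.0, mode='min'):
        self.patience = patience
        self.delta = delta
        self.mode = mode
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, metric):
        # Normalize so that a higher score is always better
        score = -metric if self.mode == 'min' else metric
        if self.best_score is None or score > self.best_score + self.delta:
            self.best_score = score
            self.counter = 0  # improvement: reset the patience counter
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

# Usage: monitor validation accuracy and stop when it plateaus
stopper = EarlyStoppingFlexible(patience=2, mode='max')
for acc in [0.70, 0.75, 0.76, 0.76, 0.76]:
    stopper(acc)
    if stopper.early_stop:
        print('stopping: accuracy plateaued')  # triggered on the 5th value
        break
```

Normalizing the score so that "higher is always better" keeps the comparison logic in one place instead of duplicating it per mode.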
Quick Reference
Early Stopping Tips:
- Monitor validation loss or accuracy, not training loss.
- Set patience to allow some epochs without improvement.
- Save the best model checkpoint to restore after stopping.
- Use a small delta to ignore tiny fluctuations.
- Load the saved model after training stops.
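To see why a small delta helps ignore tiny fluctuations, here is a minimal sketch comparing delta values on a loss curve that "improves" by negligible amounts each epoch. It reuses the counter logic from the Syntax section (checkpoint saving omitted); the loss values are made up for illustration.

```python
class EarlyStopping:
    """Same logic as the class in the Syntax section, minus checkpointing."""
    def __init__(self, patience=3, delta=0.0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.delta:
            # Change too small to count as improvement
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0

# A loss curve that improves by only 0.001 per epoch (made-up values)
losses = [1.000, 0.999, 0.998, 0.997, 0.996, 0.995]

strict = EarlyStopping(patience=3, delta=0.01)   # ignores changes below 0.01
lenient = EarlyStopping(patience=3, delta=0.0)   # counts any improvement
for loss in losses:
    strict(loss)
    lenient(loss)

print(strict.early_stop)   # True: tiny gains never reset its counter
print(lenient.early_stop)  # False: every tiny gain resets its counter
```

With delta=0.01, the 0.001-per-epoch gains fall below the threshold, so the counter keeps climbing and training stops; with delta=0, each tiny gain resets the counter and training would continue indefinitely.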
Key Takeaways
- Implement early stopping by tracking a validation metric and stopping after no improvement for a set patience period.
- Always save the best model checkpoint during training so it can be restored later.
- Reset the patience counter whenever validation improves to avoid premature stopping.
- Monitor validation loss or accuracy, not training loss, for early stopping decisions.
- Choose patience and delta values carefully to balance training time against model quality.