Introduction
Early stopping halts training once the model stops improving on the validation set, which saves training time and reduces the risk of overfitting.
Jump into concepts and practice - no test required
class EarlyStopping:
    """Signal when a monitored validation loss has stopped improving.

    Call the instance once per validation pass with the latest loss. A loss
    counts as an improvement only if it is at least ``min_delta`` below the
    best loss seen so far. After ``patience`` consecutive non-improving calls,
    ``early_stop`` is set to ``True``; nothing in this class resets it.

    Args:
        patience: Number of consecutive non-improving calls allowed before
            ``early_stop`` is set.
        min_delta: Minimum decrease below ``best_loss`` that counts as an
            improvement.
    """

    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        # Consecutive non-improving calls since the last improvement.
        self.counter = 0
        # Lowest loss seen so far; None until the first usable loss arrives.
        self.best_loss = None
        # Becomes True once `counter` reaches `patience`.
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter`` / ``early_stop``.

        A NaN loss is treated as a non-improvement.
        """
        # With no best yet, +inf makes any non-NaN first loss an improvement,
        # exactly like seeding best_loss directly.
        best = float("inf") if self.best_loss is None else self.best_loss
        # Written as `<=` rather than `not >`: NaN compares False, so a
        # diverged loss counts against patience instead of overwriting
        # best_loss and resetting the counter forever.
        if val_loss <= best - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
# Stop once the loss has gone `patience` (3) consecutive checks without
# dropping at least `min_delta` (0.01) below the best value seen so far.
early_stopping = EarlyStopping(min_delta=0.01, patience=3)

# Once per validation pass: report the latest loss, then check the flag.
# NOTE(review): `val_loss` is assumed to come from the surrounding training
# loop, which this snippet does not show.
early_stopping(val_loss)
if early_stopping.early_stop:
    print('Stop training')
import copy

import torch
import torch.nn as nn
import torch.optim as optim


# Simple model
class SimpleNet(nn.Module):
    """Linear model with one input feature and one output (y = w * x + b)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


# EarlyStopping class
class EarlyStopping:
    """Signal when a monitored validation loss has stopped improving.

    Call the instance once per validation pass with the latest loss. A loss
    counts as an improvement only if it is at least ``min_delta`` below the
    best loss seen so far. After ``patience`` consecutive non-improving calls,
    ``early_stop`` is set to ``True``; nothing in this class resets it.

    Args:
        patience: Number of consecutive non-improving calls allowed before
            ``early_stop`` is set.
        min_delta: Minimum decrease below ``best_loss`` that counts as an
            improvement.
    """

    def __init__(self, patience=3, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        # Consecutive non-improving calls since the last improvement.
        self.counter = 0
        # Lowest loss seen so far; None until the first usable loss arrives.
        self.best_loss = None
        # Becomes True once `counter` reaches `patience`.
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter`` / ``early_stop``.

        A NaN loss is treated as a non-improvement.
        """
        # With no best yet, +inf makes any non-NaN first loss an improvement.
        best = float("inf") if self.best_loss is None else self.best_loss
        # `<=` makes NaN compare False, so a diverged loss counts against
        # patience instead of becoming best_loss and resetting the counter.
        if val_loss <= best - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


# Fix the RNG so the noisy data, the initial weights and therefore the
# stopping epoch are reproducible from run to run.
torch.manual_seed(0)

# Data: y = 2x + noise, shaped (N, 1) by unsqueeze(dim=1).
x_train = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y_train = 2 * x_train + 0.1 * torch.randn(x_train.size())
x_val = torch.unsqueeze(torch.linspace(-1, 1, 20), dim=1)
y_val = 2 * x_val + 0.1 * torch.randn(x_val.size())

model = SimpleNet()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
early_stopping = EarlyStopping(patience=5, min_delta=0.001)
# Copy of the weights from the epoch with the best validation loss so far.
best_state = None

for epoch in range(100):
    # One full-batch gradient step on the training data.
    model.train()
    optimizer.zero_grad()
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

    # Validation pass without building an autograd graph.
    model.eval()
    with torch.no_grad():
        val_outputs = model(x_val)
        val_loss = criterion(val_outputs, y_val)

    print(f'Epoch {epoch+1}, Training Loss: {loss.item():.4f}, Validation Loss: {val_loss.item():.4f}')

    early_stopping(val_loss.item())
    # counter == 0 right after the call means this epoch set a new best.
    # state_dict() returns references to live tensors, so deep-copy it.
    if early_stopping.counter == 0:
        best_state = copy.deepcopy(model.state_dict())
    if early_stopping.early_stop:
        print(f'Early stopping at epoch {epoch+1}')
        break

# Restore the best-validation weights: when training stops early, the final
# `patience` epochs were all non-improving, so the current weights are not
# the best ones seen.
if best_state is not None:
    model.load_state_dict(best_state)
# Q: ...early stopping in PyTorch training?
early_stopping = EarlyStopping(patience=2, min_delta=0.01)
# Count epochs from 1 so the printed number uses the same 1-based epoch
# numbering as the full training example ("Early stopping at epoch {epoch+1}").
for epoch, val_loss in enumerate([0.5, 0.4, 0.42, 0.43], start=1):
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print(f"Stop at epoch {epoch}")
        break

early_stopping = EarlyStopping(patience=3, min_delta=0.01)
# NOTE(review): `val_losses` is not defined in this snippet; it is assumed to
# be an iterable of per-epoch validation losses supplied by the reader.
for val_loss in val_losses:
    # The flag is checked before the update, so the loop exits on the
    # iteration *after* the one that set `early_stop`. That iteration still
    # consumes one more item from `val_losses`, but no further call is made.
    if early_stopping.early_stop:
        break
    early_stopping(val_loss)

# Q: ...patience and min_delta should you use?