"""Train a simple linear model to fit y = 2x with noise.

Uses early stopping: training halts if the validation loss fails to improve
on the best value seen so far by at least 0.001 for 5 consecutive epochs.
"""
import torch
import torch.nn as nn
import torch.optim as optim
class SimpleNet(nn.Module):
    """Minimal regression model: a single 1-in / 1-out affine layer."""

    def __init__(self) -> None:
        super().__init__()
        # One scalar weight and one bias — enough capacity to learn y = 2x.
        self.linear = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Apply the affine map to `x` of shape (N, 1); returns shape (N, 1)."""
        out = self.linear(x)
        return out
class EarlyStopping:
    """Track validation loss and signal when training should stop.

    Sets `early_stop` to True after `patience` consecutive calls in which the
    loss failed to improve on the best seen value by more than `min_delta`.
    """

    def __init__(self, patience=3, min_delta=0):
        self.patience = patience      # how many stalled epochs to tolerate
        self.min_delta = min_delta    # minimum improvement that counts
        self.counter = 0              # consecutive non-improving epochs so far
        self.best_loss = None         # lowest loss observed (None until first call)
        self.early_stop = False       # flips True once patience is exhausted

    def __call__(self, val_loss):
        """Record one epoch's validation loss and update the stop flag."""
        if self.best_loss is None:
            # First observation becomes the baseline.
            self.best_loss = val_loss
            return
        improved = val_loss <= self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
# Synthetic data: targets follow y = 2x plus small Gaussian noise.
x_train = torch.linspace(-1, 1, 100).unsqueeze(1)
y_train = 2 * x_train + 0.1 * torch.randn(x_train.size())
x_val = torch.linspace(-1, 1, 20).unsqueeze(1)
y_val = 2 * x_val + 0.1 * torch.randn(x_val.size())

model = SimpleNet()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
early_stopping = EarlyStopping(patience=5, min_delta=0.001)

for epoch in range(100):
    # One full-batch gradient step on the training set.
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(x_train), y_train)
    loss.backward()
    optimizer.step()

    # Evaluate on the held-out set without tracking gradients.
    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(x_val), y_val)

    print(f'Epoch {epoch+1}, Training Loss: {loss.item():.4f}, Validation Loss: {val_loss.item():.4f}')

    # Stop once validation loss has stalled for `patience` epochs.
    early_stopping(val_loss.item())
    if early_stopping.early_stop:
        print(f'Early stopping at epoch {epoch+1}')
        break