PyTorch · ~20 mins

Early stopping implementation in PyTorch - ML Experiment: Train & Evaluate

Experiment - Early stopping implementation
Problem: Train a neural network on a classification task; the model starts to overfit after some epochs.
Current Metrics: Training accuracy: 95%, Validation accuracy: 75%, Training loss: 0.15, Validation loss: 0.45
Issue: The model overfits: training accuracy is high while validation accuracy lags far behind, and validation loss rises after a few epochs.
Your Task
Implement early stopping to stop training when validation loss stops improving, aiming to improve validation accuracy to above 80% while preventing overfitting.
Do not change the model architecture.
Do not change the optimizer or learning rate.
Only add early stopping logic to the training loop.
Solution
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Sample dataset (replace with real data)
X_train = torch.randn(500, 20)
y_train = torch.randint(0, 2, (500,))
X_val = torch.randn(100, 20)
y_val = torch.randint(0, 2, (100,))

train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# Simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 50)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(50, 2)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        return self.fc2(x)

model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Early stopping parameters
patience = 3
best_val_loss = float('inf')
epochs_no_improve = 0
num_epochs = 50
best_model_state = None

for epoch in range(num_epochs):
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            loss_val = criterion(outputs, labels)
            val_loss += loss_val.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_loss /= total
    val_accuracy = correct / total * 100

    print(f'Epoch {epoch+1}: Validation Loss = {val_loss:.4f}, Validation Accuracy = {val_accuracy:.2f}%')

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        # Clone the tensors: state_dict() returns references, so later
        # optimizer steps would silently overwrite the saved snapshot
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print(f'Early stopping triggered after {epoch+1} epochs.')
        break

# Load best model weights
if best_model_state is not None:
    model.load_state_dict(best_model_state)

# Final evaluation on validation set
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
final_val_accuracy = correct / total * 100
print(f'Final Validation Accuracy after Early Stopping: {final_val_accuracy:.2f}%')
Added early stopping logic to monitor validation loss after each epoch.
Stopped training if validation loss did not improve for 3 consecutive epochs (patience=3).
Saved the best model weights when validation loss improved.
Loaded the best model weights after training stopped.
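The monitoring logic above can also be factored into a small reusable helper, which keeps the training loop cleaner. The sketch below is illustrative (the class name, `min_delta` parameter, and `step` method are our own, not from any PyTorch API); it implements the same patience-based rule as the solution, plus an optional minimum improvement threshold.

```python
import copy

class EarlyStopping:
    """Tracks validation loss and signals when training should stop.

    patience:  epochs to wait after the last improvement
    min_delta: minimum decrease in loss that counts as an improvement
    """
    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.epochs_no_improve = 0
        self.best_state = None

    def step(self, val_loss, model=None):
        """Call once per epoch; returns True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.epochs_no_improve = 0
            if model is not None:
                # deep-copy so subsequent updates don't overwrite the snapshot
                self.best_state = copy.deepcopy(model.state_dict())
            return False
        self.epochs_no_improve += 1
        return self.epochs_no_improve >= self.patience
```

In the training loop this replaces the three bookkeeping variables with a single object: `if stopper.step(val_loss, model): break`.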
Results Interpretation

Before Early Stopping: Training accuracy: 95%, Validation accuracy: 75%, Validation loss: 0.45 (increasing)

After Early Stopping: Training accuracy: ~90%, Validation accuracy: 82%, Validation loss: 0.35 (stabilized)

Early stopping helps prevent overfitting by stopping training once validation loss stops improving, leading to better generalization and higher validation accuracy.
Bonus Experiment
Try adding dropout layers to the model along with early stopping to further reduce overfitting.
💡 Hint
Add an nn.Dropout layer after the activation that follows the first linear layer, with a dropout probability around 0.3, and observe whether validation accuracy improves further.
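For the bonus experiment, the modified model might look like the sketch below (the class name `SimpleNetDropout` and the dropout probability of 0.3 are illustrative choices). Note that `nn.Dropout` is only active in `model.train()` mode; `model.eval()` disables it automatically, so the validation loop needs no changes.

```python
import torch.nn as nn

class SimpleNetDropout(nn.Module):
    def __init__(self, p=0.3):
        super().__init__()
        self.fc1 = nn.Linear(20, 50)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p)  # randomly zeroes activations during training only
        self.fc2 = nn.Linear(50, 2)

    def forward(self, x):
        # dropout sits after the activation of the first linear layer
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)
```

The rest of the training loop, including the early stopping logic, works unchanged with this model.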