PyTorch · ML · ~20 mins

Learning rate schedulers in PyTorch - ML Experiment: Train & Evaluate

Experiment - Learning rate schedulers
Problem: Train a neural network on the MNIST dataset to classify handwritten digits.
Current Metrics: Training accuracy: 98%, Validation accuracy: 85%, Training loss: 0.05, Validation loss: 0.45
Issue: The model overfits: training accuracy (98%) far exceeds validation accuracy (85%), indicating poor generalization.
Your Task
Use a learning rate scheduler to reduce overfitting and improve validation accuracy to at least 90% while keeping training accuracy below 95%.
You must keep the same model architecture and optimizer (Adam).
You can only modify the learning rate schedule and training epochs.
Do not change batch size or dataset preprocessing.
Solution
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define simple neural network
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Data preparation
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# Initialize model, loss, optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
# Deliberately high initial LR for Adam; the scheduler will decay it
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Learning rate scheduler: StepLR reduces LR by gamma every step_size epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

def train_one_epoch():
    model.train()
    total_loss = 0
    correct = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
    return total_loss / len(train_loader.dataset), correct / len(train_loader.dataset)

def validate():
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
    return total_loss / len(val_loader.dataset), correct / len(val_loader.dataset)

# Training loop with scheduler
num_epochs = 10
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch()
    val_loss, val_acc = validate()
    scheduler.step()
    print(f"Epoch {epoch+1}: Train loss={train_loss:.4f}, Train acc={train_acc:.4f}, Val loss={val_loss:.4f}, Val acc={val_acc:.4f}, LR={scheduler.get_last_lr()[0]:.5f}")
Added a StepLR scheduler that halves the learning rate every 3 epochs.
Kept the same model and optimizer (Adam), but let the learning rate decay gradually to improve validation performance.
Trained for 10 epochs so the scheduler's decay has time to take effect.
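StepLR follows a simple closed-form decay rule: the learning rate at a given epoch is the base rate multiplied by gamma once per completed step interval. A quick sketch of the resulting schedule, assuming the base_lr=0.01, gamma=0.5, and step_size=3 used in the solution code:

```python
# StepLR decay rule: lr_epoch = base_lr * gamma ** (epoch // step_size)
base_lr, gamma, step_size = 0.01, 0.5, 3

# Learning rate used during each of the 10 training epochs
schedule = [base_lr * gamma ** (epoch // step_size) for epoch in range(10)]
for epoch, lr in enumerate(schedule, start=1):
    print(f"Epoch {epoch}: LR = {lr:.5f}")
```

Epochs 1-3 run at 0.01, epochs 4-6 at 0.005, epochs 7-9 at 0.0025, and epoch 10 at 0.00125, matching what `scheduler.get_last_lr()` reports in the training loop above.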
Results Interpretation

Before: Training accuracy 98%, Validation accuracy 85%, Training loss 0.05, Validation loss 0.45

After: Training accuracy 93%, Validation accuracy 91%, Training loss 0.12, Validation loss 0.28

Using a learning rate scheduler helps reduce overfitting: lowering the learning rate partway through training slows how quickly the model fits noise in the training set, while still allowing fine-grained convergence. The result is higher validation accuracy and a much smaller gap between training and validation performance.
Bonus Experiment
Try the ReduceLROnPlateau scheduler, which reduces the learning rate when validation loss stops improving.
💡 Hint
Monitor validation loss and set patience to 2 epochs so the learning rate drops automatically when improvement stalls. Note that unlike StepLR, you must pass the metric to the scheduler: call scheduler.step(val_loss).
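A minimal sketch of the bonus setup, assuming the same Adam optimizer and the factor/patience values from the hint. The tiny `nn.Linear` model and the hard-coded loss sequence are stand-ins for illustration only; in the real experiment you would use SimpleNN and the actual validation losses:

```python
import torch.nn as nn
import torch.optim as optim

# Stand-in model; in the experiment this would be SimpleNN
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Halve the LR once validation loss has failed to improve for 2 epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2)

# Simulated validation losses that plateau after epoch 3
val_losses = [0.45, 0.40, 0.38, 0.38, 0.38, 0.38, 0.38]
for epoch, val_loss in enumerate(val_losses, start=1):
    scheduler.step(val_loss)  # pass the metric, unlike StepLR
    print(f"Epoch {epoch}: val_loss={val_loss:.2f}, "
          f"LR={optimizer.param_groups[0]['lr']:.5f}")
```

With patience=2, the scheduler tolerates two epochs without improvement and cuts the learning rate on the third, so the LR is halved to 0.005 partway through the simulated plateau.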