
Weight decay (L2 regularization) in PyTorch - ML Experiment: Train & Evaluate

Experiment - Weight decay (L2 regularization)
Problem: Train a neural network to classify handwritten digits from the MNIST dataset. The current model achieves 99% training accuracy but only 85% validation accuracy.
Current Metrics: Training accuracy 99%, Validation accuracy 85%, Training loss 0.02, Validation loss 0.45
Issue: The model is overfitting. It performs very well on training data but poorly on validation data.
Your Task
Reduce overfitting by applying weight decay (L2 regularization) to improve validation accuracy to at least 90% while keeping training accuracy below 95%.
You can only modify the optimizer to include weight decay.
Do not change the model architecture or dataset.
Keep the number of training epochs and batch size the same.
Solution
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define simple neural network
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Prepare data
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='.', train=True, download=True, transform=transform)
# MNIST's official test split (train=False) serves as the validation set
val_dataset = datasets.MNIST(root='.', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)

# Initialize model, loss, optimizer with weight decay
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
# Added weight_decay: Adam adds weight_decay * param to every gradient (an L2 penalty)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)

def train():
    model.train()
    total_loss = 0
    correct = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.size(0)
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
    avg_loss = total_loss / len(train_loader.dataset)
    accuracy = 100. * correct / len(train_loader.dataset)
    return avg_loss, accuracy

def validate():
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item() * data.size(0)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    avg_loss = total_loss / len(val_loader.dataset)
    accuracy = 100. * correct / len(val_loader.dataset)
    return avg_loss, accuracy

# Train for 10 epochs
for epoch in range(10):
    train_loss, train_acc = train()
    val_loss, val_acc = validate()
    print(f'Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%')
Added weight_decay=0.001 parameter to the Adam optimizer to apply L2 regularization.
Kept model architecture, dataset, batch size, and epochs unchanged.
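In torch.optim.Adam, weight_decay adds weight_decay * param to each gradient before the moment estimates are updated, which is exactly the gradient of an L2 penalty (weight_decay / 2) * ||w||^2. The sketch below checks this on a tiny stand-in model and random batch (an assumption, not the MNIST network): one copy uses the optimizer argument, the other adds the penalty to the loss by hand, and after a few steps their parameters should match. By contrast, torch.optim.AdamW applies decoupled weight decay directly to the weights, which for Adam-style optimizers is not equivalent to this loss term.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()
wd = 0.001

# Small stand-in model and random batch (assumption); any model would do.
model_a = nn.Linear(4, 3)
model_b = copy.deepcopy(model_a)  # identical starting weights
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))

# A: weight decay handled by the optimizer, as in the solution.
opt_a = torch.optim.Adam(model_a.parameters(), lr=0.001, weight_decay=wd)
# B: no optimizer weight decay; explicit penalty (wd / 2) * sum of squared parameters.
opt_b = torch.optim.Adam(model_b.parameters(), lr=0.001)

for _ in range(5):
    opt_a.zero_grad()
    criterion(model_a(x), y).backward()
    opt_a.step()

    opt_b.zero_grad()
    penalty = sum((p ** 2).sum() for p in model_b.parameters())
    (criterion(model_b(x), y) + 0.5 * wd * penalty).backward()
    opt_b.step()

same = all(torch.allclose(pa, pb, atol=1e-6)
           for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)  # True: both runs take identical steps
```

Both copies also decay their bias terms, since model.parameters() passes every parameter to the optimizer.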
Results Interpretation

Before: Training accuracy 99%, Validation accuracy 85%, Training loss 0.02, Validation loss 0.45

After: Training accuracy 93%, Validation accuracy 91%, Training loss 0.12, Validation loss 0.30
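The before/after figures can be checked against the task's targets (validation at least 90%, training below 95%) with a few lines of arithmetic. The numbers below are copied from the text; summarize is a hypothetical helper written for illustration.

```python
# Metrics as reported above (accuracy in percent).
before = {"train_acc": 99.0, "val_acc": 85.0, "train_loss": 0.02, "val_loss": 0.45}
after = {"train_acc": 93.0, "val_acc": 91.0, "train_loss": 0.12, "val_loss": 0.30}


def summarize(m):
    """Return (accuracy gap in points, loss gap, whether the task target is met)."""
    acc_gap = m["train_acc"] - m["val_acc"]
    loss_gap = round(m["val_loss"] - m["train_loss"], 4)
    # Task target: validation accuracy >= 90% and training accuracy < 95%.
    meets_target = m["val_acc"] >= 90.0 and m["train_acc"] < 95.0
    return acc_gap, loss_gap, meets_target


print(summarize(before))  # (14.0, 0.43, False)
print(summarize(after))   # (2.0, 0.18, True)
```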

Adding weight decay reduces overfitting by penalizing large weights. Training accuracy falls from 99% to 93% while validation accuracy rises from 85% to 91%, so the model generalizes better and meets both task targets.
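The name "weight decay" comes from what the penalty does on its own. With plain SGD, each step computes w ← w − lr·(g + λw); when the data gradient g is zero, this reduces to multiplying the weights by (1 − lr·λ). The sketch below uses SGD rather than the solution's Adam because Adam rescales the decay term by its adaptive moment estimates, so its shrinkage is not a constant factor. The parameter values, lr, and λ are arbitrary choices for illustration.

```python
import torch

# Arbitrary example weights; the loss is built so its own gradient is exactly zero,
# which isolates the effect of weight_decay.
w = torch.nn.Parameter(torch.tensor([1.0, -2.0, 4.0]))
lr, wd = 0.1, 0.01
opt = torch.optim.SGD([w], lr=lr, weight_decay=wd)

for _ in range(3):
    opt.zero_grad()
    loss = (w * 0.0).sum()  # data term contributes a zero gradient
    loss.backward()         # w.grad is a zero tensor (not None), so SGD still updates w
    opt.step()              # w <- w - lr * (0 + wd * w) = (1 - lr * wd) * w

expected = torch.tensor([1.0, -2.0, 4.0]) * (1 - lr * wd) ** 3
print(w.detach(), expected)
```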
Bonus Experiment
Try increasing the weight decay value to 0.005 and observe how it affects training and validation accuracy.
💡 Hint
A larger weight decay may reduce overfitting further, but too much can cause underfitting and reduce overall accuracy.
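To run the bonus experiment over several values at once, a sweep can train a fresh model per weight_decay setting and record both accuracies. This is a sketch: sweep_weight_decay, accuracy, and make_model are hypothetical helpers, and the random tensors stand in for MNIST (an assumption) so it runs without a download. Unlike the solution's train(), it measures training accuracy after training, in eval mode.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def accuracy(model, loader):
    """Return the percentage of correctly classified samples in `loader`."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            correct += model(data).argmax(dim=1).eq(target).sum().item()
            total += target.size(0)
    return 100.0 * correct / total


def sweep_weight_decay(make_model, train_loader, val_loader, values, epochs=10, lr=0.001):
    """Train a fresh model for each weight_decay value; return {wd: (train_acc, val_acc)}."""
    criterion = nn.CrossEntropyLoss()
    results = {}
    for wd in values:
        torch.manual_seed(0)  # same initialization and data order for every run
        model = make_model()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
        for _ in range(epochs):
            model.train()
            for data, target in train_loader:
                optimizer.zero_grad()
                criterion(model(data), target).backward()
                optimizer.step()
        results[wd] = (accuracy(model, train_loader), accuracy(model, val_loader))
    return results


def make_model():
    # Same layer sizes as SimpleNN in the solution.
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))


# Random stand-in data (assumption) so the sketch runs without downloading MNIST.
torch.manual_seed(0)
train_loader = DataLoader(TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,))),
                          batch_size=64, shuffle=True)
val_loader = DataLoader(TensorDataset(torch.randn(128, 1, 28, 28), torch.randint(0, 10, (128,))),
                        batch_size=64)

results = sweep_weight_decay(make_model, train_loader, val_loader, values=[0.0, 0.001, 0.005], epochs=2)
for wd, (train_acc, val_acc) in results.items():
    print(f"weight_decay={wd}: train {train_acc:.1f}%, val {val_acc:.1f}%")
```

With the solution's code, passing SimpleNN and the real MNIST train_loader and val_loader (with epochs=10) keeps the exercise's constraints while comparing 0.001 against 0.005.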