0
0
PyTorchml~20 mins

Why regularization controls overfitting in PyTorch - Experiment to Prove It

Choose your learning style9 modes available
Experiment - Why regularization controls overfitting
Problem:We want to train a neural network to classify handwritten digits from the MNIST dataset. The current model fits the training data very well but performs poorly on new data.
Current Metrics:Training accuracy: 98%, Validation accuracy: 82%, Training loss: 0.05, Validation loss: 0.45
Issue:The model is overfitting: it learns the training data too well but does not generalize to validation data.
Your Task
Reduce overfitting by applying regularization techniques so that validation accuracy improves to at least 90% while keeping training accuracy below 95%.
You can only add L2 weight decay and dropout layers to the model.
Do not change the model architecture or dataset.
Keep the number of training epochs the same.
Hint 1
Hint 2
Hint 3
Solution
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the neural network with dropout
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Load data
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1000, shuffle=False)

# Initialize model, loss, optimizer with weight decay (L2 regularization)
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)  # L2 weight decay

def train():
    model.train()
    total_loss = 0
    correct = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.size(0)
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
    return total_loss / len(train_loader.dataset), correct / len(train_loader.dataset)

def validate():
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item() * data.size(0)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    return total_loss / len(val_loader.dataset), correct / len(val_loader.dataset)

# Train for 10 epochs
for epoch in range(10):
    train_loss, train_acc = train()
    val_loss, val_acc = validate()
    print(f'Epoch {epoch+1}: Train loss {train_loss:.4f}, Train acc {train_acc:.4f}, Val loss {val_loss:.4f}, Val acc {val_acc:.4f}')
Added dropout layer with 0.5 probability after the first fully connected layer to reduce neuron co-adaptation.
Added L2 weight decay (1e-4) to the Adam optimizer to penalize large weights and encourage simpler models.
Results Interpretation

Before regularization: Training accuracy was 98%, validation accuracy was 82%. The large gap shows overfitting.

After regularization: Training accuracy dropped to 93%, validation accuracy improved to 91%. The gap narrowed, showing better generalization.

Regularization techniques like dropout and L2 weight decay help prevent the model from memorizing training data. This leads to better performance on new, unseen data by encouraging simpler, more general patterns.
Bonus Experiment
Try using early stopping to stop training when validation loss stops improving to further reduce overfitting.
💡 Hint
Monitor validation loss each epoch and stop training if it does not improve for 3 consecutive epochs.