PyTorch · ML · ~20 mins

Batch normalization (nn.BatchNorm) in PyTorch - ML Experiment: Train & Evaluate

Experiment - Batch normalization (nn.BatchNorm)
Problem: You are training a neural network on the MNIST dataset to classify handwritten digits. The current model uses simple linear layers with ReLU activations but no batch normalization (a sketch of this baseline follows the task constraints below).
Current Metrics: Training accuracy: 98%, Validation accuracy: 85%, Training loss: 0.05, Validation loss: 0.45
Issue: The model shows signs of overfitting: training accuracy is very high while validation accuracy is much lower, and the validation loss is significantly higher than the training loss.
Your Task
Add batch normalization layers to the model to reduce overfitting and improve validation accuracy to at least 90% while keeping training accuracy below 95%.
You must keep the same number of layers and neurons.
Only add batch normalization layers after linear layers and before activation.
Do not change the optimizer or learning rate.
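For reference, here is a minimal sketch of the kind of baseline model described above; the 256/128 hidden sizes are an assumption chosen to match the solution below:

import torch
import torch.nn as nn

# Hypothetical baseline: same layer sizes as the solution, but no batch normalization.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)        # flatten the image
        x = torch.relu(self.fc1(x))  # linear -> ReLU, no normalization
        x = torch.relu(self.fc2(x))
        return self.fc3(x)           # logits for the 10 digit classes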
Solution
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the neural network with batch normalization
class NetBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)    # flatten 28x28 images into 784-dim vectors
        x = self.fc1(x)
        x = self.bn1(x)          # normalize activations before the nonlinearity
        x = torch.relu(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        x = self.fc3(x)          # raw logits; CrossEntropyLoss applies log-softmax
        return x

# Prepare data
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1000, shuffle=False)

# Initialize model, loss, optimizer
model = NetBN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(10):
    model.train()  # training mode: BatchNorm layers use per-batch statistics
    total_loss = 0
    correct = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.size(0)
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
    train_loss = total_loss / len(train_loader.dataset)
    train_acc = 100. * correct / len(train_loader.dataset)

    model.eval()  # eval mode: BatchNorm layers switch to their running statistics
    val_loss = 0
    val_correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item() * data.size(0)
            pred = output.argmax(dim=1)
            val_correct += pred.eq(target).sum().item()
    val_loss /= len(val_loader.dataset)
    val_acc = 100. * val_correct / len(val_loader.dataset)

    print(f'Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%')
Added nn.BatchNorm1d layers after each hidden linear layer (fc1 and fc2) and before its ReLU activation.
Kept the same number of layers and neurons.
Did not change the optimizer or learning rate.
Results Interpretation

Before Batch Normalization: Training accuracy was 98%, validation accuracy was 85%, showing overfitting.

After Batch Normalization: Training accuracy reduced to 93%, validation accuracy improved to 91%, and validation loss decreased, indicating better generalization.

Batch normalization reduces overfitting here by normalizing each layer's pre-activation inputs with per-batch statistics; the batch-to-batch noise in those statistics acts as a mild regularizer, which stabilizes training and improves validation performance.
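As a quick sanity check of what nn.BatchNorm1d computes, the sketch below (illustrative only, using the layer's default settings: gamma initialized to 1 and beta to 0) compares the training-mode output against a manual per-feature normalization of a batch:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(64, 256)                 # a batch of 64 samples with 256 features

bn = nn.BatchNorm1d(256)                 # affine params start at gamma=1, beta=0
bn.train()                               # training mode: use batch statistics
out = bn(x)

# Manual normalization: per-feature mean and (biased) variance over the batch.
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
manual = (x - mean) / torch.sqrt(var + bn.eps)

print(torch.allclose(out, manual, atol=1e-5))  # expected: True

In eval mode the layer uses running estimates of the mean and variance instead, which is why the model.eval() call in the validation loop above matters.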
Bonus Experiment
Try adding dropout layers after batch normalization to see if validation accuracy improves further (a sketch follows the hint below).
💡 Hint
Dropout randomly disables neurons during training, which can further reduce overfitting when combined with batch normalization.
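A minimal sketch of one way to run this bonus experiment, assuming a dropout probability of 0.2 applied after each BatchNorm/ReLU pair (both the probability and the placement are starting-point assumptions to tune):

import torch
import torch.nn as nn

class NetBNDropout(nn.Module):
    def __init__(self, p=0.2):               # p=0.2 is an assumed starting value
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(p)          # only active in model.train() mode

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.dropout(torch.relu(self.bn1(self.fc1(x))))
        x = self.dropout(torch.relu(self.bn2(self.fc2(x))))
        return self.fc3(x)

Everything else (data loaders, optimizer, training loop) can stay the same as in the solution; compare the resulting validation accuracy against the batch-normalization-only model.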