
Batch size and shuffling in PyTorch - ML Experiment: Train & Evaluate

Experiment - Batch size and shuffling
Problem: Training a neural network on the MNIST dataset to classify handwritten digits.
Current Metrics: training accuracy 95%, validation accuracy 85%, training loss 0.15, validation loss 0.35
Issue: The model shows signs of overfitting, with a large gap between training and validation accuracy. The batch size is set to 128 and shuffling is disabled in the data loader.
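The configuration described in the issue corresponds to a training loader built with `batch_size=128` and `shuffle=False`. The sketch below reproduces that configuration on a stand-in dataset of 60,000 sample indices. This is an assumption made so the example runs without downloading MNIST, whose training split also has 60,000 images. It checks two consequences: each epoch has 469 mini-batches (60,000 / 128 rounded up), and every epoch visits the samples in exactly the same order.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the MNIST training split: 60,000 sample indices instead of images
# (assumption: keeps the example self-contained, no download needed).
indices = TensorDataset(torch.arange(60_000))

# Configuration described in the issue: batch size 128, no shuffling.
before_loader = DataLoader(indices, batch_size=128, shuffle=False)

def epoch_order(loader):
    """Concatenate the sample indices in the order one epoch visits them."""
    return torch.cat([batch for (batch,) in loader])

batches_per_epoch = len(before_loader)  # ceil(60000 / 128)
epoch1 = epoch_order(before_loader)
epoch2 = epoch_order(before_loader)

print("batches per epoch:", batches_per_epoch)
print("same order every epoch:", torch.equal(epoch1, epoch2))
```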
Your Task
Reduce overfitting by adjusting batch size and enabling data shuffling to improve validation accuracy to above 90% while keeping training accuracy below 95%.
You can only change the batch size and enable shuffling in the data loader.
Do not change the model architecture or optimizer settings.
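The success criteria can be expressed as a small check. The helper names `meets_target` and `generalization_gap` are hypothetical. The numbers plugged in are the before and after figures reported in this experiment, and the gap is training accuracy minus validation accuracy, in percentage points.

```python
def meets_target(train_acc: float, val_acc: float) -> bool:
    """Hypothetical helper: True when validation accuracy > 90% and training accuracy < 95%."""
    return val_acc > 90.0 and train_acc < 95.0

def generalization_gap(train_acc: float, val_acc: float) -> float:
    """Percentage-point gap between training and validation accuracy."""
    return train_acc - val_acc

# Accuracy figures reported for this experiment (percent).
before = {"train_acc": 95.0, "val_acc": 85.0}
after = {"train_acc": 93.0, "val_acc": 91.0}

for name, m in (("before", before), ("after", after)):
    print(name, "gap =", generalization_gap(**m), "pts, meets target:", meets_target(**m))
```

The original run fails the check on both counts. Training accuracy is not strictly below 95%, and validation accuracy is 5 points short of 90%.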
Solution
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load datasets (the official MNIST test split, train=False, serves as the validation set)
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Data loaders: batch size reduced to 32; shuffle only the training set (validation order does not affect the metrics)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Simple neural network
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
    def forward(self, x):
        x = self.flatten(x)
        return self.linear(x)

# Initialize model, loss, optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(5):
    model.train()
    train_loss = 0
    correct_train = 0
    total_train = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        correct_train += predicted.eq(labels).sum().item()
        total_train += labels.size(0)
    train_loss /= total_train
    train_acc = 100. * correct_train / total_train

    # Evaluate on the validation set without updating weights
    model.eval()
    val_loss = 0
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            correct_val += predicted.eq(labels).sum().item()
            total_val += labels.size(0)
    val_loss /= total_val
    val_acc = 100. * correct_val / total_val

    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%")
Reduced batch size from 128 to 32 to increase gradient noise and improve generalization.
Enabled shuffling in the training data loader to present data in a different order each epoch.
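Both changes can be checked in isolation with the same kind of stand-in index dataset. This assumes 60,000 indices in place of MNIST images. With `batch_size=32`, an epoch has 1,875 mini-batches instead of 469. The unchanged Adam optimizer therefore takes roughly four times as many steps per epoch. With `shuffle=True`, each epoch draws a fresh random permutation of the samples.

Passing a seeded `torch.Generator` through the `generator` argument of `DataLoader` is optional. It is used here only so the shuffle order is reproducible from run to run.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the MNIST training split (assumption: indices instead of images).
indices = TensorDataset(torch.arange(60_000))

# Seeded generator so the shuffle order is reproducible from run to run.
g = torch.Generator().manual_seed(0)

# Configuration from the solution: batch size 32, shuffling enabled.
after_loader = DataLoader(indices, batch_size=32, shuffle=True, generator=g)

def epoch_order(loader):
    """Concatenate the sample indices in the order one epoch visits them."""
    return torch.cat([batch for (batch,) in loader])

epoch1 = epoch_order(after_loader)
epoch2 = epoch_order(after_loader)

print("batches per epoch:", len(after_loader))  # 60000 / 32
print("epochs differ in order:", not torch.equal(epoch1, epoch2))
# Each epoch is still a permutation: every sample appears exactly once.
print("epoch 1 covers all samples:", torch.equal(epoch1.sort().values, torch.arange(60_000)))
```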
Results Interpretation

Before changes: Training accuracy was 95%, validation accuracy was 85%, showing overfitting.

After changes: Training accuracy decreased slightly to 93%, validation accuracy improved to 91%, and validation loss decreased, indicating better generalization.

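Reducing the batch size and enabling shuffling help reduce overfitting. Smaller batches give noisier gradient estimates, which tends to act as a mild regularizer. Shuffling splits each epoch into different mini-batches presented in a different order, so the model does not repeatedly follow the same sequence of updates. In this experiment, the two changes together narrowed the gap between training and validation accuracy from 10 to 2 percentage points.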
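The gradient-noise argument can be made concrete with an illustration on synthetic data, not MNIST. The sketch fixes a linear model with a squared-error loss and computes the full-dataset gradient. It then measures how far the mini-batch gradients from one shuffled epoch stray from that full gradient, at batch sizes 32 and 128. The data, the model, and the `minibatch_gradient_noise` helper are assumptions made for the illustration.

A mini-batch gradient is an average over B samples, so its variance shrinks roughly in proportion to 1/B. The batch-32 estimates should therefore deviate about four times as much, in mean squared distance, as the batch-128 ones. Whether that extra noise improves validation accuracy on a particular problem still has to be measured, as the experiment above does.

```python
import torch

torch.manual_seed(0)

# Synthetic regression data (assumption: stands in for a real dataset).
N, D = 4096, 10
X = torch.randn(N, D)
y = X @ torch.randn(D) + 0.5 * torch.randn(N)
w = torch.zeros(D)  # fixed parameters at which all gradients are evaluated

def grad(xb, yb):
    """Gradient of 0.5 * mean((xb @ w - yb) ** 2) with respect to w."""
    residual = xb @ w - yb
    return xb.T @ residual / xb.shape[0]

full_grad = grad(X, y)

def minibatch_gradient_noise(batch_size):
    """Hypothetical helper: mean squared distance between mini-batch gradients
    and the full gradient over one shuffled epoch."""
    perm = torch.randperm(N)
    sq_dev = []
    for start in range(0, N, batch_size):
        idx = perm[start:start + batch_size]
        g = grad(X[idx], y[idx])
        sq_dev.append(((g - full_grad) ** 2).sum())
    return torch.stack(sq_dev).mean().item()

noise_32 = minibatch_gradient_noise(32)
noise_128 = minibatch_gradient_noise(128)
print(f"batch 32:  {noise_32:.4f}")
print(f"batch 128: {noise_128:.4f}")
# Theory predicts a ratio of roughly 4, because the variance of a mean scales like 1/B.
print(f"ratio:     {noise_32 / noise_128:.2f}")
```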
Bonus Experiment
Try increasing the batch size to 256 and disabling shuffling. Observe how this affects overfitting and validation accuracy.
💡 Hint
Larger batch sizes give smoother (less noisy) gradient estimates. Without shuffling, every epoch presents the same mini-batches in the same order. Both effects can increase overfitting and lower validation accuracy.
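To compare configurations side by side, the solution's training and evaluation loop can be wrapped in a function that takes the batch size and shuffle flag as arguments. `run_experiment` and `evaluate` are hypothetical helpers. They keep the solution's architecture, loss, and optimizer (`Adam`, `lr=0.001`) fixed.

One difference from the solution: training metrics here are measured after training, in eval mode, rather than accumulated during the epoch. The usage below runs on small random tensors shaped like MNIST images, purely so the sketch executes quickly without a download. Those numbers are meaningless. For the actual bonus experiment, pass the solution's `train_dataset` and `val_dataset` with `batch_size=256, shuffle=False`.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

class SimpleNN(nn.Module):
    """Same architecture as in the solution."""
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))

    def forward(self, x):
        return self.linear(self.flatten(x))

def evaluate(model, loader, criterion):
    """Return (mean loss, accuracy in %) over a loader without updating weights."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images)
            total_loss += criterion(outputs, labels).item() * images.size(0)
            correct += outputs.argmax(1).eq(labels).sum().item()
            total += labels.size(0)
    return total_loss / total, 100.0 * correct / total

def run_experiment(train_ds, val_ds, batch_size, shuffle, epochs=5, seed=0):
    """Hypothetical helper: train SimpleNN with one loader configuration and
    return the final (train_loss, train_acc, val_loss, val_acc)."""
    torch.manual_seed(seed)  # same initial weights for every configuration
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=shuffle)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    model = SimpleNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for _ in range(epochs):
        model.train()
        for images, labels in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
    # Metrics measured after training, in eval mode.
    train_loss, train_acc = evaluate(model, DataLoader(train_ds, batch_size=batch_size), criterion)
    val_loss, val_acc = evaluate(model, val_loader, criterion)
    return train_loss, train_acc, val_loss, val_acc

# Usage on random MNIST-shaped tensors (assumption: placeholders, not real data).
gen = torch.Generator().manual_seed(1)
fake_train = TensorDataset(torch.randn(512, 1, 28, 28, generator=gen),
                           torch.randint(0, 10, (512,), generator=gen))
fake_val = TensorDataset(torch.randn(128, 1, 28, 28, generator=gen),
                         torch.randint(0, 10, (128,), generator=gen))

results = {}
for bs, sh in [(128, False), (32, True), (256, False)]:
    results[(bs, sh)] = run_experiment(fake_train, fake_val, bs, sh, epochs=2)
    print(f"batch_size={bs:>3}, shuffle={sh}:", [round(v, 3) for v in results[(bs, sh)]])
```

Resetting the seed inside `run_experiment` gives every configuration the same initial weights. Differences in the results then come from the loader settings rather than from initialization.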