
Gradient accumulation and zeroing in PyTorch - ML Experiment: Train & Evaluate

Experiment - Gradient accumulation and zeroing
Problem: You are training a neural network on a small GPU that cannot fit a large batch size. Currently, you use a batch size of 16, but the model trains slowly and the gradients are reset after every batch. This limits the effective batch size and slows learning.
Current Metrics: Training loss decreases slowly, and validation accuracy reaches about 70% after 10 epochs.
Issue: The model trains slowly because the batch size is small. Gradients are zeroed after every batch, so the model cannot accumulate gradients over multiple batches to simulate a larger batch size.
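For comparison, the current setup described above (batch size 16, gradients reset before every batch) amounts to a loop like the one below. This is a minimal sketch that assumes the exercise's synthetic data and SimpleNet; the optimizer_steps counter is a hypothetical addition that shows every 16-sample batch triggers its own weight update.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic data matching the exercise: 1000 samples, 10 features
X = torch.randn(1000, 10)
y = (X.sum(dim=1) > 0).long()
dataloader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)


model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

optimizer_steps = 0  # hypothetical counter, not part of the exercise
for inputs, labels in dataloader:
    optimizer.zero_grad()  # baseline: gradients reset before every batch
    loss = criterion(model(inputs), labels)
    loss.backward()
    optimizer.step()  # one update per 16-sample batch
    optimizer_steps += 1

# 1000 samples / 16 per batch -> 63 batches, so 63 updates per epoch
print(f"Optimizer steps in one epoch: {optimizer_steps}")
```

Accumulating over 4 batches instead makes roughly a quarter as many updates per epoch, each based on about 64 samples.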
Your Task
Implement gradient accumulation to simulate a larger batch size of 64 by accumulating gradients over 4 batches before updating weights. Ensure gradients are zeroed correctly to avoid incorrect updates.
Do not change the batch size from 16.
Do not change the model architecture.
Use PyTorch for implementation.
Solution
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Simple dataset
X = torch.randn(1000, 10)
y = (X.sum(dim=1) > 0).long()
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
    def forward(self, x):
        return self.fc(x)

model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

accumulation_steps = 4  # accumulate gradients over 4 batches

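num_batches = len(dataloader)  # 1000 samples / 16 -> 63 batches (the last has 8)

for epoch in range(10):
    running_loss = 0.0
    optimizer.zero_grad()  # start each epoch with clean gradients
    for i, (inputs, labels) in enumerate(dataloader):
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss = loss / accumulation_steps  # scale so the summed gradients average over the window
        loss.backward()  # gradients add into .grad until they are zeroed

        # Update every 4 batches, and also on the last batch so the final
        # partial window (63 % 4 = 3 batches) is not silently discarded
        if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
            optimizer.step()  # update weights with the accumulated gradients
            optimizer.zero_grad()  # clear gradients for the next window

        running_loss += loss.item() * accumulation_steps  # undo scaling for logging

    print(f"Epoch {epoch+1}, Loss: {running_loss / num_batches:.4f}")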
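Added gradient accumulation: each batch's loss is divided by accumulation_steps, and loss.backward() runs on 4 consecutive batches before a single optimizer.step(), so the summed gradients equal the average over 64 samples.
Called optimizer.zero_grad() only at the start of each epoch and right after each optimizer.step(), instead of after every batch, so gradients can accumulate across the window.
Used the batch index i from enumerate to trigger an update every 4 batches, giving an effective batch size of 4 × 16 = 64.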
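The claim that dividing the loss by accumulation_steps simulates a batch of 64 can be checked directly. This minimal sketch, assuming the same 10-feature synthetic data and a single nn.Linear(10, 2) layer (the same shape as SimpleNet), compares the gradient of one 64-sample batch with the gradient accumulated over four 16-sample micro-batches.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

accumulation_steps = 4
micro_batch = 16

# One "large" batch of 64 samples, later split into 4 micro-batches of 16
X = torch.randn(accumulation_steps * micro_batch, 10)
y = (X.sum(dim=1) > 0).long()

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()  # mean reduction by default

# (1) Gradient from a single batch of 64
model.zero_grad()
criterion(model(X), y).backward()
full_grads = [p.grad.clone() for p in model.parameters()]

# (2) Gradient accumulated over 4 micro-batches, each loss scaled by 1/4
model.zero_grad()
for xb, yb in zip(X.split(micro_batch), y.split(micro_batch)):
    (criterion(model(xb), yb) / accumulation_steps).backward()
accum_grads = [p.grad.clone() for p in model.parameters()]

matches = all(torch.allclose(a, b, atol=1e-6) for a, b in zip(full_grads, accum_grads))
print(f"Accumulated gradients match the 64-sample gradient: {matches}")
```

Without the division by accumulation_steps, the accumulated gradient would be 4 times larger, which with plain SGD acts like multiplying the learning rate by 4.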
Results Interpretation

Before: Batch size 16, gradients zeroed every batch, slow training, validation accuracy ~70%.

After: Effective batch size 64 via gradient accumulation, gradients zeroed only after weight update, faster training, validation accuracy ~78%.
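The solution script above logs only training loss, so it does not itself produce the validation accuracies quoted here. One way to measure them is sketched below, assuming a hypothetical 800/200 train/validation split made with random_split; the evaluate helper and the 3-epoch run are illustrative choices, not part of the exercise.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(0)

# Same synthetic task as the exercise
X = torch.randn(1000, 10)
y = (X.sum(dim=1) > 0).long()

# Hypothetical 800/200 train/validation split (the exercise does not specify one)
train_set, val_set = random_split(TensorDataset(X, y), [800, 200])
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16)


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)


def evaluate(model, loader):
    """Return the fraction of correctly classified samples in loader."""
    model.eval()  # evaluation mode (no effect on SimpleNet, but good practice)
    correct, total = 0, 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for inputs, labels in loader:
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    model.train()  # restore training mode for the next epoch
    return correct / total


model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4
num_batches = len(train_loader)  # 800 / 16 = 50 batches

for epoch in range(3):
    optimizer.zero_grad()
    for i, (inputs, labels) in enumerate(train_loader):
        loss = criterion(model(inputs), labels) / accumulation_steps
        loss.backward()
        if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
            optimizer.step()
            optimizer.zero_grad()
    val_acc = evaluate(model, val_loader)
    print(f"Epoch {epoch+1}, validation accuracy: {val_acc:.2%}")
```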

Gradient accumulation simulates a larger batch size on limited hardware by summing gradients over several smaller batches before each weight update. Zeroing at the right point is essential: zeroing after every batch throws the accumulated gradients away, while never zeroing lets gradients from earlier updates leak into later ones.
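The reason the loss is divided by accumulation_steps can be written out. Take K = accumulation_steps micro-batches B_1, ..., B_K of b samples each, a per-sample loss ℓ, and L_B as the mean loss over the combined batch B (notation introduced here for illustration). With the mean-reduced CrossEntropyLoss:

```latex
\begin{aligned}
L_k(\theta) &= \frac{1}{b} \sum_{j \in B_k} \ell(\theta; x_j, y_j), \qquad k = 1, \dots, K \\
\sum_{k=1}^{K} \nabla_\theta \frac{L_k(\theta)}{K}
  &= \frac{1}{Kb} \sum_{k=1}^{K} \sum_{j \in B_k} \nabla_\theta \ell(\theta; x_j, y_j)
   = \nabla_\theta L_B(\theta), \qquad B = \bigcup_{k=1}^{K} B_k,\ |B| = Kb = 4 \times 16 = 64
\end{aligned}
```

This equality assumes equal-sized micro-batches and no batch-dependent layers such as BatchNorm. SimpleNet has no such layers, but the final 8-sample batch of each epoch does not meet the equal-size assumption exactly.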
Bonus Experiment
Try adding gradient clipping after accumulation to prevent exploding gradients and observe its effect on training stability.
💡 Hint
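Call torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) once per accumulation window: after the last loss.backward() of the window and immediately before optimizer.step(), so clipping acts on the full accumulated gradient rather than on each batch's partial gradient.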
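A minimal sketch of the bonus experiment, assuming a single nn.Linear(10, 2) (equivalent to SimpleNet) and a hypothetical threshold max_norm = 1.0. clip_grad_norm_ returns the total gradient norm measured before clipping, which is logged here to see how often the threshold is hit; the post-clip norm is recomputed only to confirm the clip took effect.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

X = torch.randn(1000, 10)
y = (X.sum(dim=1) > 0).long()
dataloader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)

model = nn.Linear(10, 2)  # same shape as SimpleNet's single layer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

accumulation_steps = 4
max_norm = 1.0  # hypothetical clipping threshold; tune for your model
num_batches = len(dataloader)
pre_clip_norms = []   # total gradient norm before clipping, one per update
post_clip_norms = []  # total gradient norm after clipping, one per update

optimizer.zero_grad()
for i, (inputs, labels) in enumerate(dataloader):
    loss = criterion(model(inputs), labels) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
        # Clip once per window, on the fully accumulated gradient
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        pre_clip_norms.append(total_norm.item())
        # Recompute the global L2 norm to confirm it is now at most max_norm
        clipped = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
        post_clip_norms.append(clipped.item())
        optimizer.step()
        optimizer.zero_grad()

print(f"Updates: {len(pre_clip_norms)}, largest pre-clip norm: {max(pre_clip_norms):.3f}")
```

If the logged pre-clip norms rarely exceed max_norm, clipping has little effect at this learning rate; lowering max_norm makes its influence on training stability easier to observe.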