
Gradient accumulation in PyTorch

Introduction
Gradient accumulation lets you train with an effectively large batch size by splitting it into smaller mini-batches and accumulating their gradients before each weight update, saving memory while keeping training stable. Use it when:
Your GPU memory is too small to fit a large batch at once.
You want to simulate large-batch training but only have resources for smaller batches.
You are training large models that need stable updates under limited GPU memory.
You want to improve training stability by averaging gradients over several mini-batches.
Syntax
PyTorch
optimizer.zero_grad()
for i, data in enumerate(dataloader):
    outputs = model(data.inputs)
    loss = loss_function(outputs, data.labels) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
Divide the loss by the number of accumulation steps to average gradients correctly.
Call optimizer.step() only after accumulating gradients over several mini-batches.
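To see why the division matters, here is a small self-contained check (a sketch with made-up shapes and a fixed seed, not part of the original example) comparing the gradient from one full batch against the gradient accumulated over 4 equal mini-batches. With nn.MSELoss's default mean reduction and equal-sized chunks, the two should agree up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny model and a "large" batch of 8 samples (arbitrary illustrative sizes)
model = nn.Linear(4, 1)
inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)
criterion = nn.MSELoss()

# Gradient from the full batch in one backward pass
model.zero_grad()
criterion(model(inputs), targets).backward()
full_grad = model.weight.grad.clone()

# Gradient accumulated over 4 mini-batches of 2, loss scaled by 1/accumulation_steps
accumulation_steps = 4
model.zero_grad()
for chunk_in, chunk_tgt in zip(inputs.chunk(accumulation_steps),
                               targets.chunk(accumulation_steps)):
    loss = criterion(model(chunk_in), chunk_tgt) / accumulation_steps
    loss.backward()  # gradients add up across calls
accumulated_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accumulated_grad, atol=1e-6))  # True
```

Without the division by accumulation_steps, the accumulated gradient would be the sum of the per-chunk mean gradients rather than their average, i.e. 4 times too large here.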
Examples
Accumulate gradients over 4 mini-batches before updating model weights.
PyTorch
accumulation_steps = 4
optimizer.zero_grad()
for i, (inputs, labels) in enumerate(dataloader):
    outputs = model(inputs)
    loss = criterion(outputs, labels) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
Here, gradients are accumulated over 2 batches to simulate a larger batch size.
PyTorch
accumulation_steps = 2
optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    inputs, targets = batch
    outputs = model(inputs)
    loss = loss_fn(outputs, targets) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
Sample Model
This code trains a simple linear model on random data using gradient accumulation over 4 mini-batches before updating weights. It also performs a final update if the number of mini-batches is not divisible by accumulation_steps.
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Simple dataset
inputs = torch.randn(20, 5)
labels = torch.randn(20, 1)
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset, batch_size=2)

# Simple model
model = nn.Linear(5, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

accumulation_steps = 4
optimizer.zero_grad()

for i, (x, y) in enumerate(dataloader):
    outputs = model(x)
    loss = criterion(outputs, y) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Perform a final update if the number of batches is not a multiple of accumulation_steps
if (i + 1) % accumulation_steps != 0:
    optimizer.step()
    optimizer.zero_grad()

# Check final loss on the dataset
with torch.no_grad():
    preds = model(inputs)
    final_loss = criterion(preds, labels).item()

print(f"Final loss after gradient accumulation: {final_loss:.4f}")
Important Notes
Always divide the loss by accumulation_steps to keep gradient scale correct.
Remember to call optimizer.zero_grad() after each optimizer.step() to reset gradients.
If the number of mini-batches is not divisible by accumulation_steps, the loop ends with leftover gradients; perform one final optimizer.step() after the loop so they are not discarded.
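An alternative to a post-loop fix-up is to trigger the update inside the loop whenever the final batch is reached. This sketch (with a made-up 10-sample dataset and batch_size=2, giving 5 batches against accumulation_steps=4) counts the resulting weight updates.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# 10 samples with batch_size=2 -> 5 batches, not divisible by accumulation_steps
dataset = TensorDataset(torch.randn(10, 5), torch.randn(10, 1))
dataloader = DataLoader(dataset, batch_size=2)

model = nn.Linear(5, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

accumulation_steps = 4
num_batches = len(dataloader)
optimizer.zero_grad()
updates = 0
for i, (x, y) in enumerate(dataloader):
    loss = criterion(model(x), y) / accumulation_steps
    loss.backward()
    # Step on every accumulation boundary, and also on the final batch
    if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
        optimizer.step()
        optimizer.zero_grad()
        updates += 1

print(updates)  # 2: one full accumulation of 4 batches, one partial of 1
```

Note that the partial group's loss is still divided by the full accumulation_steps, so its update is slightly smaller than a full group's; whether that matters depends on how many leftover batches there are relative to accumulation_steps.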
Summary
Gradient accumulation splits a large batch into smaller steps to save memory.
It averages gradients over multiple mini-batches before updating model weights.
Useful for training large models or when GPU memory is limited.