
Gradient accumulation and zeroing in PyTorch

Introduction

Gradient accumulation lets you train with a large effective batch size while only holding a small batch in memory. Zeroing clears old gradients so each update uses fresh gradient information.

When your machine's memory is too small to hold a large training batch.
When you want to simulate a large batch size by summing gradients from several small batches.
When training deep neural networks that benefit from the more stable updates of larger batches.
When you want to control how often the optimizer updates the weights.
When you want to avoid mixing stale gradients from previous steps into the current update.
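The use cases above rest on one fact: summing gradients from small micro-batches can reproduce the gradient of one large batch. A minimal sketch (the seed, data shapes, and split size are arbitrary choices for illustration; each micro-batch loss is divided by the number of micro-batches so the accumulated sum matches the full-batch mean):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(2, 1)
loss_fn = nn.MSELoss()
x = torch.randn(4, 2)
y = torch.randn(4, 1)

# Gradient of the full batch of 4 samples
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Same gradient accumulated over 2 micro-batches of 2 samples;
# dividing each loss by 2 turns the accumulated sum into the mean
model.zero_grad()
for xb, yb in zip(x.split(2), y.split(2)):
    (loss_fn(model(xb), yb) / 2).backward()

print(torch.allclose(model.weight.grad, full_grad, atol=1e-6))  # True
```

Without the division, the accumulated gradient is the sum rather than the mean, which effectively multiplies the learning rate by the number of micro-batches.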
Syntax
PyTorch
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()  # gradients accumulate in the .grad buffers
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

zero_grad() clears the old gradients before the next round of backward passes.

Gradients accumulate in the .grad buffers over several batches; call optimizer.step() once they have all been added.

Examples
Basic training step: clear gradients, compute gradients, update weights.
PyTorch
optimizer.zero_grad()
loss.backward()
optimizer.step()
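The three calls above can be wrapped in a complete loop. A runnable sketch on synthetic data (the model size, learning rate, and step count are arbitrary choices for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
x, y = torch.randn(8, 3), torch.randn(8, 1)

loss_before = loss_fn(model(x), y).item()
for _ in range(20):
    optimizer.zero_grad()        # clear old gradients
    loss = loss_fn(model(x), y)  # forward pass
    loss.backward()              # compute new gradients
    optimizer.step()             # update weights
print(loss_before, loss_fn(model(x), y).item())
```

After the loop the printed loss should be lower than before training, since each step follows the fresh gradient of the current loss.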
Accumulate gradients over 4 batches before updating model.
PyTorch
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    if (i + 1) % 4 == 0:  # update once every 4 batches
        optimizer.step()
        optimizer.zero_grad()
Sample Model

This example shows training a simple linear model on 8 samples. It accumulates gradients over 4 samples before updating weights. We print predictions before and after one epoch to see the change.

PyTorch
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)
    def forward(self, x):
        return self.linear(x)

# Data: 8 samples, 2 features
inputs = torch.tensor([[1.0, 2.0], [2.0, 1.0], [3.0, 4.0], [4.0, 3.0],
                       [5.0, 6.0], [6.0, 5.0], [7.0, 8.0], [8.0, 7.0]])
labels = torch.tensor([[3.0], [3.0], [7.0], [7.0], [11.0], [11.0], [15.0], [15.0]])

model = SimpleModel()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

accumulation_steps = 4

print('Before training:')
with torch.no_grad():
    preds = model(inputs)
    print(preds.squeeze().tolist())

optimizer.zero_grad()
for i in range(len(inputs)):
    input_i = inputs[i].unsqueeze(0)
    label_i = labels[i].unsqueeze(0)
    output = model(input_i)
    loss = loss_fn(output, label_i)
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

print('\nAfter one epoch with gradient accumulation:')
with torch.no_grad():
    preds = model(inputs)
    print(preds.squeeze().tolist())
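A quick check of the update schedule in the sample above: with 8 samples and accumulation_steps = 4, optimizer.step() runs twice per epoch, and each update is based on 4 accumulated per-sample gradients. The arithmetic can be verified directly:

```python
# Replay the sample's update condition: (i + 1) % accumulation_steps == 0
accumulation_steps = 4
num_samples = 8
updates = sum(1 for i in range(num_samples) if (i + 1) % accumulation_steps == 0)
print(updates)                        # 2 optimizer updates per epoch
print(num_samples // updates)         # 4 samples contribute to each update
```

Note that the sample works cleanly because 8 is divisible by 4; if the sample count were not a multiple of accumulation_steps, the leftover gradients at the end of the epoch would never trigger an update.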
Important Notes

Always call optimizer.zero_grad() before starting to accumulate gradients.

Gradient accumulation lets you use bigger batch sizes without needing more memory.

If you forget to zero gradients, old gradients are added on top of the new ones, which corrupts the updates.
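The last note is easy to verify: backward() adds into the .grad buffer rather than overwriting it. A minimal demonstration on a single scalar:

```python
import torch

x = torch.tensor([2.0], requires_grad=True)

(x ** 2).backward()   # dy/dx = 2x = 4
print(x.grad)         # tensor([4.])

# Without zeroing, a second backward ADDS to the stored gradient
(x ** 2).backward()
print(x.grad)         # tensor([8.])

x.grad.zero_()        # zeroing resets the buffer
(x ** 2).backward()
print(x.grad)         # tensor([4.])
```

This additive behavior is exactly what gradient accumulation exploits on purpose, and exactly what corrupts training when zeroing is forgotten.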

Summary

Gradient accumulation sums gradients over multiple batches before updating the model.

Zeroing gradients clears old information to keep learning correct.

This technique helps train with limited memory and control update frequency.