Gradient accumulation lets you train with a large effective batch size while keeping memory usage small. Zeroing gradients clears the previous step's gradients so each update is computed only from the intended batches.
Gradient accumulation and zeroing in PyTorch
```python
optimizer.zero_grad()
for i, (inputs, labels) in enumerate(dataloader):
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()  # gradients are added to any already stored
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
optimizer.zero_grad() clears the old gradients before the next backward pass.
Accumulate gradients over several batches before calling optimizer.step().
```python
optimizer.zero_grad()
loss.backward()
optimizer.step()
```
```python
optimizer.zero_grad()
for i, (inputs, labels) in enumerate(dataloader):
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()
    if (i + 1) % 4 == 0:  # step every 4 batches
        optimizer.step()
        optimizer.zero_grad()
```
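One refinement worth noting: repeated backward() calls sum the per-batch gradients, so the accumulated gradient is a sum, not a mean. A common fix (a sketch, not part of the example above; the synthetic `data` list stands in for a real DataLoader) is to divide each loss by accumulation_steps so the update matches what one large batch would produce:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(2, 1)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
accumulation_steps = 4

# Synthetic stand-in for a real DataLoader: 8 (input, target) pairs.
data = [(torch.randn(1, 2), torch.randn(1, 1)) for _ in range(8)]

optimizer.zero_grad()
for i, (x, y) in enumerate(data):
    # Dividing by accumulation_steps makes the summed gradients equal
    # the mean gradient of one batch of accumulation_steps samples.
    loss = loss_fn(model(x), y) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

Without the division, accumulating over 4 batches effectively multiplies the learning rate by 4.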
This example trains a simple linear model on 8 samples, accumulating gradients over 4 samples before each weight update. Predictions are printed before and after one epoch to show the change.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

# Data: 8 samples, 2 features
inputs = torch.tensor([[1.0, 2.0], [2.0, 1.0],
                       [3.0, 4.0], [4.0, 3.0],
                       [5.0, 6.0], [6.0, 5.0],
                       [7.0, 8.0], [8.0, 7.0]])
labels = torch.tensor([[3.0], [3.0], [7.0], [7.0],
                       [11.0], [11.0], [15.0], [15.0]])

model = SimpleModel()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
accumulation_steps = 4

print('Before training:')
with torch.no_grad():
    preds = model(inputs)
print(preds.squeeze().tolist())

optimizer.zero_grad()
for i in range(len(inputs)):
    input_i = inputs[i].unsqueeze(0)
    label_i = labels[i].unsqueeze(0)
    output = model(input_i)
    loss = loss_fn(output, label_i)
    loss.backward()
    # Update only after accumulating gradients from 4 samples.
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

print('\nAfter one epoch with gradient accumulation:')
with torch.no_grad():
    preds = model(inputs)
print(preds.squeeze().tolist())
```
Always call optimizer.zero_grad() before starting to accumulate gradients.
Gradient accumulation lets you simulate a larger effective batch size without using more memory.
If you forget to zero gradients, stale gradients from previous iterations are added to the new ones, corrupting the update.
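A minimal demonstration of that pitfall, using a hypothetical one-parameter tensor rather than the model above: calling backward() twice without zeroing in between sums the two gradients.

```python
import torch

w = torch.tensor([1.0], requires_grad=True)

(2 * w).sum().backward()
print(w.grad)  # tensor([2.])

# No zeroing in between: the new gradient is ADDED to the old one.
(2 * w).sum().backward()
print(w.grad)  # tensor([4.]), not tensor([2.])

# Clearing the gradient restores correctness for the next pass.
w.grad = None
(2 * w).sum().backward()
print(w.grad)  # tensor([2.])
```

Setting `w.grad = None` here does by hand what optimizer.zero_grad() does for every parameter the optimizer manages.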
Gradient accumulation sums gradients over multiple batches before updating the model.
Zeroing gradients discards stale values so each update is computed from fresh gradients.
The technique helps when memory is limited and gives you control over how often updates happen.
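One practical caveat, sketched under the assumption that the dataset size is not a multiple of accumulation_steps (the 8-sample example above divides evenly, so it never hits this): without a final step, the last few batches' gradients are silently dropped. A common fix is to flush them after the loop:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
accumulation_steps = 4

# 10 samples: not divisible by 4, so samples 9 and 10 need a final step.
samples = [(torch.randn(1, 2), torch.randn(1, 1)) for _ in range(10)]

optimizer.zero_grad()
for i, (x, y) in enumerate(samples):
    loss_fn(model(x), y).backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Flush the leftover accumulated gradients (here, from the last 2 samples).
if len(samples) % accumulation_steps != 0:
    optimizer.step()
    optimizer.zero_grad()
```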