PyTorch · How-To · Beginner · 4 min read

How to Use Gradient Accumulation in PyTorch for Large Batch Training

In PyTorch, call loss.backward() on every batch, call optimizer.step() only after accumulating gradients over accumulation_steps batches, and call optimizer.zero_grad() once per accumulation cycle. This technique simulates a larger effective batch size by accumulating gradients across several small batches before updating the model weights.
📐

Syntax

Gradient accumulation involves these key steps inside your training loop:

  • optimizer.zero_grad(): Clears old gradients. Call this once before starting accumulation.
  • loss.backward(): Computes gradients for the current batch and adds them to existing gradients.
  • optimizer.step(): Updates model weights using accumulated gradients. Call this after a set number of batches.

You control how many batches to accumulate with accumulation_steps.

python
accumulation_steps = 4
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
💻

Example

This example trains a simple model on dummy data, processing one sample at a time and accumulating gradients over 2 steps. It prints the (unscaled) loss at every weight update.

python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)
    def forward(self, x):
        return self.linear(x)

# Dummy data
inputs = torch.randn(8, 10)
targets = torch.randn(8, 1)

model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

accumulation_steps = 2
optimizer.zero_grad()

for i in range(8):
    input_batch = inputs[i].unsqueeze(0)
    target_batch = targets[i].unsqueeze(0)
    output = model(input_batch)
    loss = loss_fn(output, target_batch) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        print(f"Step {i+1}: Loss = {loss.item()*accumulation_steps:.4f}")
Output
Step 2: Loss = 0.9649
Step 4: Loss = 0.2043
Step 6: Loss = 0.0863
Step 8: Loss = 0.0224

Exact values vary between runs because the model and data are randomly initialized without a fixed seed.
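The same pattern works with a DataLoader instead of indexing samples by hand. This is a minimal sketch, assuming dummy tensors and a batch size of 2 (so 4 mini-batches, giving 2 weight updates with accumulation_steps = 2):

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# Dummy dataset: 8 samples in mini-batches of 2 -> 4 mini-batches total
dataset = TensorDataset(torch.randn(8, 10), torch.randn(8, 1))
dataloader = DataLoader(dataset, batch_size=2)

accumulation_steps = 2
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets) / accumulation_steps  # scale for accumulation
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        print(f"Update after mini-batch {i + 1}")
```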
⚠️

Common Pitfalls

Common mistakes when using gradient accumulation include:

  • Not dividing the loss by accumulation_steps, which makes the accumulated gradients accumulation_steps times too large.
  • Calling optimizer.zero_grad() too often, which clears gradients before accumulation.
  • Forgetting to call optimizer.step() after the last batch if total batches are not divisible by accumulation_steps.
python
accumulation_steps = 4
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)  # Missing division by accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
# Wrong: loss not scaled, gradients too large

# Correct way:
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets) / accumulation_steps  # Scale loss
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
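The third pitfall (leftover batches) can be handled by also stepping at the very last batch. A minimal sketch, assuming a DataLoader whose mini-batch count (here 5) is not divisible by accumulation_steps:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# 10 samples in mini-batches of 2 -> 5 mini-batches, not divisible by 4
dataset = TensorDataset(torch.randn(10, 10), torch.randn(10, 1))
dataloader = DataLoader(dataset, batch_size=2)

accumulation_steps = 4
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets) / accumulation_steps
    loss.backward()
    # Step every accumulation_steps batches OR at the final batch,
    # so the leftover batch still contributes to an update
    if (i + 1) % accumulation_steps == 0 or (i + 1) == len(dataloader):
        optimizer.step()
        optimizer.zero_grad()
```

Without the `or (i + 1) == len(dataloader)` condition, the gradients from the fifth mini-batch would be computed but never applied.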
📊

Quick Reference

| Step | Action | Notes |
|------|--------|-------|
| 1 | Call optimizer.zero_grad() | Clear gradients before accumulation |
| 2 | For each batch: compute loss and call loss.backward() | Divide the loss by accumulation_steps |
| 3 | Every accumulation_steps batches: call optimizer.step() | Update weights once per accumulation |
| 4 | Call optimizer.zero_grad() after optimizer.step() | Prepare for the next accumulation cycle |

Key Takeaways

  • Divide the loss by accumulation_steps to keep the gradient scale correct.
  • Call optimizer.zero_grad() only once per accumulation cycle, not on every batch.
  • Call optimizer.step() only after accumulating gradients over multiple batches.
  • Handle leftover batches if the total batch count is not divisible by accumulation_steps.
  • Gradient accumulation simulates larger batch sizes without the extra memory a larger batch would need.
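To see that accumulation really simulates a larger batch, this sketch compares the gradient from one full batch of 8 samples against gradients accumulated over two equal half-batches. Because MSELoss averages over the batch, scaling each half-batch loss by 1/2 makes the two gradients match:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
inputs = torch.randn(8, 10)
targets = torch.randn(8, 1)
loss_fn = nn.MSELoss()
model = nn.Linear(10, 1)

# Gradient from the full batch of 8 samples
model.zero_grad()
loss_fn(model(inputs), targets).backward()
full_grad = model.weight.grad.clone()

# Accumulated gradient over two half-batches of 4, each loss scaled by 1/2
model.zero_grad()
for half_in, half_tg in zip(inputs.chunk(2), targets.chunk(2)):
    (loss_fn(model(half_in), half_tg) / 2).backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # True
```

Note that this equivalence relies on the half-batches being equal in size; with unequal splits and a mean-reduced loss, the scaling factors would need adjusting.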