How to Use Gradient Accumulation in PyTorch for Large Batch Training
In PyTorch, call loss.backward() on every batch, call optimizer.step() only once every accumulation_steps batches, and call optimizer.zero_grad() immediately after each step. This technique lets you simulate a larger batch size by accumulating gradients before updating model weights.
Syntax
Gradient accumulation involves these key steps inside your training loop:
- optimizer.zero_grad(): Clears old gradients. Call this once before starting accumulation.
- loss.backward(): Computes gradients for the current batch and adds them to existing gradients.
- optimizer.step(): Updates model weights using accumulated gradients. Call this after a set number of batches.
You control how many batches to accumulate with accumulation_steps.
```python
accumulation_steps = 4

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
Example
This example trains a simple model on dummy data using gradient accumulation with 2 steps and prints the loss at every update. Because the data is randomly generated, your exact loss values will differ.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

# Dummy data
inputs = torch.randn(8, 10)
targets = torch.randn(8, 1)

model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

accumulation_steps = 2

optimizer.zero_grad()
for i in range(8):
    input_batch = inputs[i].unsqueeze(0)
    target_batch = targets[i].unsqueeze(0)
    output = model(input_batch)
    loss = loss_fn(output, target_batch) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        print(f"Step {i+1}: Loss = {loss.item()*accumulation_steps:.4f}")
```
Output
Step 2: Loss = 0.9649
Step 4: Loss = 0.2043
Step 6: Loss = 0.0863
Step 8: Loss = 0.0224
Common Pitfalls
Common mistakes when using gradient accumulation include:
- Not dividing the loss by accumulation_steps, which causes gradients to be too large.
- Calling optimizer.zero_grad() too often, which clears gradients before accumulation finishes.
- Forgetting to call optimizer.step() after the last batch if the total number of batches is not divisible by accumulation_steps.
```python
accumulation_steps = 4

# Wrong: loss not scaled, so accumulated gradients are too large
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)  # Missing division by accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Correct: scale the loss before backward()
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets) / accumulation_steps  # Scale loss
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
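The third pitfall, leftover batches, can be handled with a final flush after the loop. A minimal sketch (the dummy dataloader, model, and optimizer here are illustrative stand-ins, not part of the snippets above):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# 10 micro-batches; 10 is not divisible by accumulation_steps = 4
dataloader = [(torch.randn(1, 10), torch.randn(1, 1)) for _ in range(10)]
accumulation_steps = 4

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Flush gradients left over from the final incomplete cycle
# (here: the last 2 of the 10 batches)
if len(dataloader) % accumulation_steps != 0:
    optimizer.step()
    optimizer.zero_grad()
```

Without the final check, the gradients from the last two batches would be computed but never applied to the weights.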
Quick Reference
| Step | Action | Notes |
|---|---|---|
| 1 | Call optimizer.zero_grad() | Clear gradients before accumulation |
| 2 | For each batch: compute loss and call loss.backward() | Divide loss by accumulation_steps |
| 3 | Every accumulation_steps batches: call optimizer.step() | Update weights once per accumulation |
| 4 | Call optimizer.zero_grad() after optimizer.step() | Prepare for next accumulation cycle |
Key Takeaways
Divide loss by accumulation_steps to keep gradient scale correct.
Call optimizer.zero_grad() only once per accumulation cycle, not every batch.
Call optimizer.step() after accumulating gradients over multiple batches.
Remember to handle leftover batches if the total number of batches is not divisible by accumulation_steps.
Gradient accumulation simulates larger batch sizes without extra memory.
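As a sanity check on the last point: with a mean-reduction loss such as nn.MSELoss and equal-size micro-batches, accumulating scaled gradients reproduces the gradient of one large batch. A small sketch (not from the article; the 8-sample setup mirrors the example above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
inputs = torch.randn(8, 10)
targets = torch.randn(8, 1)
loss_fn = nn.MSELoss()

model = nn.Linear(10, 1)

# Gradient from one full batch of 8
full_loss = loss_fn(model(inputs), targets)
full_loss.backward()
full_grad = model.weight.grad.clone()

# Accumulated gradient from 4 micro-batches of 2 (same weights, no updates)
model.zero_grad()
accumulation_steps = 4
for i in range(accumulation_steps):
    batch = inputs[2 * i:2 * i + 2]
    tgt = targets[2 * i:2 * i + 2]
    loss = loss_fn(model(batch), tgt) / accumulation_steps
    loss.backward()

# The two gradients agree up to floating-point error
print(torch.allclose(full_grad, model.weight.grad, atol=1e-5))  # True
```

The memory saving comes from never materializing the activations for more than one micro-batch at a time.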