How to Use optimizer.zero_grad in PyTorch: Clear Guide
In PyTorch, call optimizer.zero_grad() to clear old gradients before computing new ones during training. This prevents gradients from accumulating across batches, ensuring correct weight updates.

Syntax
The method optimizer.zero_grad() is called on an optimizer object to reset all gradients of model parameters to zero. This is necessary before calling loss.backward() to compute fresh gradients for the current batch.
- optimizer: The optimizer instance managing model parameters (e.g., SGD, Adam).
- zero_grad(): Method that sets all parameter gradients to zero.
```python
optimizer.zero_grad()
```
Example
This example shows a simple training step where optimizer.zero_grad() is used to clear gradients before backpropagation and optimizer step.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Dummy input and target
inputs = torch.tensor([[1.0, 2.0]])
target = torch.tensor([[1.0]])

# Forward pass
output = model(inputs)
loss = nn.MSELoss()(output, target)

# Clear old gradients
optimizer.zero_grad()

# Backward pass
loss.backward()

# Update weights
optimizer.step()

print(f"Loss: {loss.item():.4f}")
```
Output

The exact value depends on the model's random initialization, so it varies between runs; one possible result:

Loss: 0.2500
Common Pitfalls
One common mistake is forgetting to call optimizer.zero_grad() before loss.backward(). This causes gradients to accumulate from multiple batches, leading to incorrect updates and unstable training.
Another pitfall is calling zero_grad() in the wrong place: between loss.backward() and optimizer.step(), it wipes the freshly computed gradients, so the update step does nothing. Calling it immediately after optimizer.step() is fine, as long as it still runs before the next backward pass.
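To make the "too early" case concrete, here is a minimal sketch using a single throwaway parameter (not part of the article's example). Zeroing between backward() and step() erases the gradient, so the optimizer step leaves the parameter unchanged:

```python
import torch

# Sketch: zeroing between backward() and step() wipes the fresh gradients,
# so the optimizer step becomes a no-op.
w = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

loss = (3 * w).sum()
loss.backward()        # w.grad is now tensor([3.])
optimizer.zero_grad()  # too early: gradient wiped before the update
optimizer.step()       # no-op -- w is unchanged

print(w)  # tensor([1.], requires_grad=True)
```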
```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
inputs = torch.tensor([[1.0, 2.0]])
target = torch.tensor([[1.0]])

# WRONG: missing zero_grad call -- gradients from earlier
# batches accumulate into this update
output = model(inputs)
loss = nn.MSELoss()(output, target)
loss.backward()
optimizer.step()

# CORRECT: clear old gradients before the backward pass
# (the forward pass is re-run to build a fresh graph)
optimizer.zero_grad()
output = model(inputs)
loss = nn.MSELoss()(output, target)
loss.backward()
optimizer.step()
```
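To see the accumulation numerically, here is a minimal sketch with a single throwaway parameter; the gradient values in the comments follow from d(3w)/dw = 3:

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# First backward pass: d(3*w)/dw = 3
(3 * w).sum().backward()
print(w.grad)  # tensor([3.])

# Second backward pass WITHOUT zero_grad: gradients add up (3 + 3 = 6)
(3 * w).sum().backward()
print(w.grad)  # tensor([6.])

# Zeroing first, then one backward pass gives the fresh gradient again
optimizer.zero_grad()
(3 * w).sum().backward()
print(w.grad)  # tensor([3.])
```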
Quick Reference
- Call optimizer.zero_grad() before loss.backward() each training step.
- Clears gradients to avoid accumulation across batches.
- Ensures correct gradient computation and weight updates.
- Use with any PyTorch optimizer (SGD, Adam, etc.).
Key Takeaways
- Always call optimizer.zero_grad() before loss.backward() to reset gradients.
- Not zeroing gradients causes accumulation and incorrect training.
- optimizer.zero_grad() works with all PyTorch optimizers.
- Place zero_grad() at the start of each training iteration.
- Correct gradient management leads to stable and accurate model updates.
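Putting the placement rule into practice, a minimal training loop might look like the following sketch; the model, data, and learning rate are illustrative placeholders, not prescribed values:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Placeholder model and data, chosen only for illustration
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
targets = torch.tensor([[1.0], [2.0]])

for epoch in range(20):
    optimizer.zero_grad()            # 1. clear old gradients first
    output = model(inputs)           # 2. forward pass
    loss = loss_fn(output, targets)  # 3. compute the loss
    loss.backward()                  # 4. compute fresh gradients
    optimizer.step()                 # 5. update the weights
```

The same pattern works with any optimizer; only the optim.SGD line changes.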