
How to Use optimizer.zero_grad in PyTorch: Clear Guide

In PyTorch, use optimizer.zero_grad() to clear old gradients before computing new ones during training. This prevents gradients from accumulating across batches, ensuring correct weight updates.

Syntax

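The method optimizer.zero_grad() is called on an optimizer object to reset the gradients of every parameter that optimizer manages. Because loss.backward() adds new gradients to whatever is already stored in each parameter's .grad attribute, the reset has to happen before loss.backward() so that the gradients reflect only the current batch. In PyTorch 2.0 and later the gradients are set to None by default (set_to_none=True) rather than filled with zeros; for an ordinary training loop the effect is the same.

  • optimizer: The optimizer instance managing model parameters (e.g., SGD, Adam).
  • zero_grad(): Method that resets the gradients of all managed parameters (to None by default in recent versions, or to zero tensors with set_to_none=False).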
python
optimizer.zero_grad()
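
To make the effect concrete, here is a minimal sketch that inspects a parameter's .grad attribute before and after the call. It passes set_to_none explicitly so the result does not depend on the installed PyTorch version; the layer size and input values are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Populate .grad with one backward pass on arbitrary data.
# For sum(x @ W.T + b), the gradient with respect to W is x itself.
model(torch.tensor([[1.0, 2.0]])).sum().backward()
grad_before = model.weight.grad.clone()   # tensor([[1., 2.]])

# set_to_none=False keeps the gradient tensors and fills them with zeros
optimizer.zero_grad(set_to_none=False)
grad_zeroed = model.weight.grad.clone()   # tensor([[0., 0.]])

# set_to_none=True drops the gradient tensors entirely
optimizer.zero_grad(set_to_none=True)
grad_is_none = model.weight.grad is None  # True

print(grad_before, grad_zeroed, grad_is_none)
```

Either form leaves nothing behind for the next backward pass to add to, which is all a standard training loop needs.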

Example

This example shows a single training step in which optimizer.zero_grad() clears the gradients before the backward pass and the optimizer step.

python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Dummy input and target
inputs = torch.tensor([[1.0, 2.0]])
target = torch.tensor([[1.0]])

# Forward pass
output = model(inputs)
loss = nn.MSELoss()(output, target)

# Clear old gradients
optimizer.zero_grad()

# Backward pass
loss.backward()

# Update weights
optimizer.step()

print(f"Loss: {loss.item():.4f}")
Example output (the exact value depends on the randomly initialized weights and will differ between runs)
Loss: 0.2500
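
In practice the same three calls repeat once per batch. The sketch below extends the single step into a short loop over synthetic data; the dataset, batch size, learning rate, and epoch count are assumptions made up for illustration and are not part of the example above.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)  # make the run repeatable

# Synthetic regression data: y = 2*x1 - 3*x2 + 1 (made up for this sketch)
X = torch.randn(64, 2)
y = X @ torch.tensor([[2.0], [-3.0]]) + 1.0

model = nn.Linear(2, 1)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

with torch.no_grad():
    initial_loss = loss_fn(model(X), y).item()

for epoch in range(20):
    for start in range(0, len(X), 16):     # mini-batches of 16 samples
        xb, yb = X[start:start + 16], y[start:start + 16]
        optimizer.zero_grad()              # clear gradients from the previous batch
        loss = loss_fn(model(xb), yb)      # forward pass
        loss.backward()                    # fresh gradients for this batch only
        optimizer.step()                   # update weights

with torch.no_grad():
    final_loss = loss_fn(model(X), y).item()

print(f"Loss before: {initial_loss:.4f}, after: {final_loss:.4f}")
```

Putting zero_grad() first in the loop body means every backward pass starts from a clean slate, whatever happened in the previous iteration.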

Common Pitfalls

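One common mistake is forgetting to call optimizer.zero_grad() before loss.backward(). PyTorch adds each new gradient to the value already stored in .grad, so without the reset every update uses the sum of the gradients from all previous batches, which leads to incorrect updates and unstable training.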
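
The accumulation is easy to observe directly. In the sketch below (the layer, data, and variable names are illustrative assumptions), the same batch goes through backward() twice without a reset in between, and the stored gradient ends up at twice the single-pass value.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
inputs = torch.tensor([[1.0, 2.0]])
target = torch.tensor([[1.0]])

# First backward pass: .grad holds the gradient for this batch
loss_fn(model(inputs), target).backward()
single = model.weight.grad.clone()

# Second backward pass without zero_grad(): the new gradient is added on top
loss_fn(model(inputs), target).backward()
accumulated = model.weight.grad.clone()
doubled = torch.allclose(accumulated, 2 * single)

# Clearing first restores the single-batch value
optimizer.zero_grad()
loss_fn(model(inputs), target).backward()
restored = torch.allclose(model.weight.grad, single)

print(doubled, restored)
```

No optimizer.step() is called here, so the weights stay fixed and the three backward passes compute the same gradient; only the handling of .grad differs.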

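Another pitfall is calling zero_grad() between loss.backward() and optimizer.step(). This discards the gradients that were just computed, so the update ignores the current batch. Calling zero_grad() after optimizer.step(), at the end of the iteration, is fine, because the gradients are still cleared before the next backward pass.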

python
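import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
inputs = torch.tensor([[1.0, 2.0]])
target = torch.tensor([[1.0]])

# WRONG: no zero_grad call, so the second iteration's backward()
# adds its gradient to the one left over from the first iteration
for _ in range(2):
    loss = loss_fn(model(inputs), target)
    loss.backward()
    optimizer.step()

# CORRECT: clear gradients, then run a fresh forward and backward pass
for _ in range(2):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), target)
    loss.backward()
    optimizer.step()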
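
The ordering mistake can be checked in a similar way. This sketch (names and data are illustrative assumptions) clears the gradients between backward() and step() with plain SGD, then confirms that the weights did not move, and repeats the step in the correct order for comparison.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)  # plain SGD: no momentum, no weight decay
loss_fn = nn.MSELoss()
inputs = torch.tensor([[1.0, 2.0]])
target = torch.tensor([[1.0]])

weights_before = model.weight.detach().clone()

# WRONG order: the gradients are discarded before the optimizer can use them
loss_fn(model(inputs), target).backward()
optimizer.zero_grad()
optimizer.step()
unchanged = torch.equal(model.weight.detach(), weights_before)

# Correct order: zero_grad, backward, step
optimizer.zero_grad()
loss_fn(model(inputs), target).backward()
optimizer.step()
changed = not torch.equal(model.weight.detach(), weights_before)

print(unchanged, changed)
```

With an optimizer that keeps internal state, such as momentum, the weights might still move in the wrong-order case, but the update would not reflect the current batch.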

Quick Reference

  • Call optimizer.zero_grad() before loss.backward() each training step.
  • Clears gradients to avoid accumulation across batches.
  • Ensures correct gradient computation and weight updates.
  • Use with any PyTorch optimizer (SGD, Adam, etc.).
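
As the last bullet suggests, the call looks the same regardless of the optimizer. A minimal sketch with Adam, where the learning rate, step count, and data are arbitrary assumptions:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(2, 1)
optimizer = optim.Adam(model.parameters(), lr=0.01)  # only this line differs from the SGD version
loss_fn = nn.MSELoss()
inputs = torch.tensor([[1.0, 2.0]])
target = torch.tensor([[1.0]])

losses = []
for _ in range(20):
    optimizer.zero_grad()                  # same call, same position in the loop
    loss = loss_fn(model(inputs), target)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```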

Key Takeaways

  • Always call optimizer.zero_grad() before loss.backward() to reset gradients.
  • Not zeroing gradients causes accumulation and incorrect training.
  • optimizer.zero_grad() works with all PyTorch optimizers.
  • Place zero_grad() at the start of each training iteration.
  • Correct gradient management leads to stable and accurate model updates.