This code shows how zeroing gradients works in a simple training step: it clears any stale gradients before the backward pass, computes fresh gradients, updates the weights, and then clears the gradients again so they do not accumulate into the next step.
import torch
import torch.nn as nn
import torch.optim as optim
# Simple linear model
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
# Dummy input and target
inputs = torch.tensor([[1.0, 2.0]])
target = torch.tensor([[1.0]])
# Forward pass
output = model(inputs)
# Compute mean squared error loss
loss_fn = nn.MSELoss()
loss = loss_fn(output, target)
# Zero gradients before backward pass
optimizer.zero_grad()
# Backward pass to compute gradients
loss.backward()
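# Note: backward() accumulates into .grad (it adds rather than overwrites),
# which is why the gradients were zeroed above before this call.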
# Capture gradients before the optimizer step (clone them, since the later
# zero_grad call with set_to_none=False zeroes the same tensors in place)
grads_before = [param.grad.clone() for param in model.parameters()]
# Update weights
optimizer.step()
# Zero gradients again for the next step. Recent PyTorch versions default to
# set_to_none=True, which replaces .grad with None; set_to_none=False keeps
# zero-filled tensors so they can be printed below.
optimizer.zero_grad(set_to_none=False)
grads_after = [param.grad for param in model.parameters()]
print("Gradients before optimizer step:")
for g in grads_before:
    print(g)
print("\nGradients after zero_grad call:")
for g in grads_after:
    print(g)
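Why does the zeroing matter? PyTorch accumulates gradients: each backward() call adds into .grad rather than replacing it. As a minimal sketch reusing the model, inputs, target, and loss_fn defined above, calling backward() twice without zeroing in between doubles the stored gradient:

# Two backward passes with no zero_grad() in between: gradients sum.
loss_fn(model(inputs), target).backward()
first = model.weight.grad.clone()
loss_fn(model(inputs), target).backward()
print(torch.allclose(model.weight.grad, 2 * first))  # True: gradients accumulated

Skipping the zeroing step is occasionally done on purpose, for example to simulate a larger batch via gradient accumulation, but in an ordinary training loop it silently mixes gradients from different steps.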