How to Use backward() in PyTorch for Gradient Computation
In PyTorch, use tensor.backward() to compute gradients of a scalar tensor with respect to the graph leaves. This function accumulates gradients into the .grad attribute of leaf tensors that have requires_grad=True.
Syntax
The basic signature is tensor.backward(gradient=None, retain_graph=None, create_graph=False).

- tensor: the scalar tensor (single value) whose gradients you want to compute.
- gradient: optional tensor specifying the gradient of a downstream output w.r.t. this tensor; required for non-scalar outputs, usually omitted for scalars.
- retain_graph: if True, keeps the computation graph for further backward passes (defaults to the value of create_graph).
- create_graph: if True, constructs the graph of the derivative itself, enabling higher-order gradients.

```python
tensor.backward(gradient=None, retain_graph=None, create_graph=False)
```
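As a minimal sketch of the retain_graph parameter (the tensor values here are illustrative), calling backward() twice on the same graph fails unless the first call retains it; note that the second call also accumulates into x.grad:

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 2

# First backward pass; keep the graph alive for a second pass.
y.backward(retain_graph=True)
print(x.grad)  # dy/dx = 2x = 6.0

# Second backward pass on the same graph; gradients accumulate.
y.backward()
print(x.grad)  # 6.0 + 6.0 = 12.0
```

Without retain_graph=True on the first call, the second y.backward() would raise a RuntimeError because the graph is freed after the first pass.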
Example
This example shows how to compute gradients of a simple function y = 3x^2 at x=2. After calling backward(), the gradient dy/dx = 6x is stored in x.grad.
```python
import torch

# Create a tensor with gradient tracking
x = torch.tensor(2.0, requires_grad=True)

# Define a function y = 3 * x^2
y = 3 * x ** 2

# Compute gradients (dy/dx)
y.backward()

# Print the gradient stored in x.grad
print(f"Gradient dy/dx at x=2: {x.grad.item()}")
```
Output
```
Gradient dy/dx at x=2: 12.0
```
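The create_graph flag extends this to higher-order derivatives. A sketch (with hypothetical values) computing the second derivative of y = x^3 at x = 2, using torch.autograd.grad to get a differentiable first derivative:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3

# First derivative dy/dx = 3x^2; create_graph=True makes it differentiable.
(grad_x,) = torch.autograd.grad(y, x, create_graph=True)
print(grad_x.item())  # 3 * 2^2 = 12.0

# Differentiate again: d2y/dx2 = 6x.
grad_x.backward()
print(x.grad.item())  # 6 * 2 = 12.0
```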
Common Pitfalls
- Calling backward on a non-scalar tensor: you must pass a gradient argument when the tensor is not a scalar.
- Not enabling requires_grad: gradients are only computed for tensors with requires_grad=True.
- Overwriting gradients: gradients accumulate by default; clear them with optimizer.zero_grad() or .grad.zero_() before a new backward call.
- Retaining the computation graph: by default the graph is freed after backward; set retain_graph=True if you need multiple backward passes.
```python
import torch

# Wrong: backward on a non-scalar tensor without a gradient argument
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2
try:
    y.backward()
except RuntimeError as e:
    print(f"Error: {e}")

# Right: provide a gradient argument matching y's shape
y.backward(torch.tensor([1.0, 1.0, 1.0]))
print(f"Gradients: {x.grad}")
```
Output
```
Error: grad can be implicitly created only for scalar outputs
Gradients: tensor([2., 2., 2.])
```
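The accumulation pitfall can be sketched the same way (values here are illustrative): repeated backward calls add to .grad until you clear it in place with .grad.zero_():

```python
import torch

x = torch.tensor(1.0, requires_grad=True)

# Two backward passes without clearing: gradients add up.
for step in range(2):
    y = 2 * x
    y.backward()
print(x.grad.item())  # 2.0 + 2.0 = 4.0

# Clear the accumulated gradient in place before the next pass.
x.grad.zero_()
y = 2 * x
y.backward()
print(x.grad.item())  # 2.0, a fresh gradient
```

In a training loop, optimizer.zero_grad() performs the same clearing for every parameter the optimizer manages.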
Key Takeaways
- Use tensor.backward() to compute gradients of scalar tensors in PyTorch.
- Ensure tensors have requires_grad=True so their operations are tracked.
- For non-scalar tensors, pass a gradient argument to backward().
- Gradients accumulate by default; clear them before each new backward call.
- Set retain_graph=True if you need to call backward multiple times on the same graph.