PyTorch · How-To · Beginner · 3 min read

How to Use backward() in PyTorch for Gradient Computation

In PyTorch, use tensor.backward() to compute gradients of a scalar tensor with respect to graph leaves. This function accumulates gradients in the .grad attribute of tensors with requires_grad=True.
📐

Syntax

The basic syntax to compute gradients is tensor.backward(gradient=None, retain_graph=None, create_graph=False).

  • tensor: The scalar (single-value) tensor on which backward() is called.
  • gradient: Optional tensor giving the gradient of some final scalar w.r.t. this tensor; required when the tensor is not a scalar, usually omitted for scalar outputs.
  • retain_graph: If True, keeps the computation graph so backward can be called again; defaults to the value of create_graph.
  • create_graph: If True, builds a graph of the backward pass itself, enabling higher-order gradients.
python
tensor.backward(gradient=None, retain_graph=None, create_graph=False)
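As a brief sketch of the create_graph option: calling backward(create_graph=True) makes the computed gradient itself differentiable, so a second derivative can be taken (here via torch.autograd.grad, part of the same autograd API):

```python
import torch

# Sketch: higher-order gradients with create_graph=True
x = torch.tensor(2.0, requires_grad=True)
y = x ** 3  # dy/dx = 3x^2, d2y/dx2 = 6x

# Build the graph of the backward pass so x.grad is itself differentiable
y.backward(create_graph=True)
print(x.grad.item())  # first derivative: 3 * 2^2 = 12.0

# Differentiate the first derivative to get the second derivative
second = torch.autograd.grad(x.grad, x)[0]
print(second.item())  # second derivative: 6 * 2 = 12.0
```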
💻

Example

This example shows how to compute gradients of a simple function y = 3x^2 at x=2. After calling backward(), the gradient dy/dx = 6x is stored in x.grad.

python
import torch

# Create a tensor with gradient tracking
x = torch.tensor(2.0, requires_grad=True)

# Define a function y = 3 * x^2
y = 3 * x ** 2

# Compute gradients (dy/dx)
y.backward()

# Print the gradient stored in x.grad
print(f"Gradient dy/dx at x=2: {x.grad.item()}")
Output
Gradient dy/dx at x=2: 12.0
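Because backward() accumulates into .grad rather than overwriting it, repeating the pass above on a fresh graph doubles the stored gradient. A minimal sketch:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

# First pass: dy/dx = 6x = 12 at x = 2
(3 * x ** 2).backward()
print(x.grad.item())  # 12.0

# Second pass on a fresh graph: gradients accumulate, not overwrite
(3 * x ** 2).backward()
print(x.grad.item())  # 24.0

# Clear the accumulated gradient before the next pass
x.grad.zero_()
(3 * x ** 2).backward()
print(x.grad.item())  # 12.0
```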
⚠️

Common Pitfalls

  • Calling backward on non-scalar tensors: You must provide a gradient argument if the tensor is not a scalar.
  • Not enabling requires_grad: Gradients are only computed for tensors with requires_grad=True.
  • Forgetting that gradients accumulate: backward() adds to existing .grad values; clear them with optimizer.zero_grad() or .grad.zero_() before new backward calls.
  • Retaining computation graph: By default, the graph is freed after backward; set retain_graph=True if you need multiple backward passes.
python
import torch

# Wrong: backward on non-scalar without gradient argument
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2
try:
    y.backward()
except RuntimeError as e:
    print(f"Error: {e}")

# Right: provide gradient argument
y.backward(torch.tensor([1.0, 1.0, 1.0]))
print(f"Gradients: {x.grad}")
Output
Error: grad can be implicitly created only for scalar outputs
Gradients: tensor([2., 2., 2.])
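The last pitfall can be sketched as well: a second backward() on the same graph raises a RuntimeError unless the first call retained the graph (values assume the same y = 3x² example used above):

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = 3 * x ** 2

# Keep the graph alive so backward can run a second time
y.backward(retain_graph=True)
print(x.grad.item())  # 12.0

# Second backward on the same graph; gradients accumulate
y.backward()
print(x.grad.item())  # 24.0
```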

Key Takeaways

Use tensor.backward() to compute gradients for scalar tensors in PyTorch.
Ensure tensors have requires_grad=True to track operations for gradients.
For non-scalar tensors, provide a gradient argument to backward().
Gradients accumulate by default; clear them before new backward calls.
Set retain_graph=True if you need to call backward multiple times on the same graph.