
What is retain_graph in PyTorch: Explanation and Usage

In PyTorch, retain_graph is a parameter used in backward() to keep the computation graph after computing gradients. By default, PyTorch frees the graph to save memory, but setting retain_graph=True allows multiple backward passes on the same graph.

⚙️ How It Works

When you run backward() in PyTorch, it calculates gradients by tracing back through the computation graph that recorded all operations. Normally, after this backward pass, PyTorch frees this graph to save memory, like cleaning up after finishing a task.

However, sometimes you want to do more than one backward pass on the same graph, like reusing a recipe to bake multiple cakes without rewriting it each time. Setting retain_graph=True tells PyTorch to keep the graph alive after the first backward call, so you can call backward() again on the same graph without errors.
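
To see why the flag exists, here is a minimal sketch of the failure case: with the default retain_graph=False, the first backward() call frees the graph's saved tensors, so a second call raises a RuntimeError.
python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x * x  # y = x^2

y.backward()        # default retain_graph=False frees the graph here
try:
    y.backward()    # nothing left to trace back through
except RuntimeError as err:
    print(f'Second backward failed: {err}')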

💻 Example

This example shows how retain_graph=True allows two backward passes on the same graph. Because PyTorch accumulates gradients in x.grad, the second pass adds another 4.0 on top of the first.
python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x * x  # y = x^2

# First backward pass
y.backward(retain_graph=True)
print(f'Gradient after first backward: {x.grad.item()}')

# Second backward pass on the same graph
# Gradients accumulate in x.grad, so this adds another 4.0 to the 4.0 above
y.backward()
print(f'Gradient after second backward: {x.grad.item()}')
Output
Gradient after first backward: 4.0
Gradient after second backward: 8.0
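
The second value is 8.0 because PyTorch adds new gradients to x.grad rather than overwriting it. If you want each pass to report the plain derivative dy/dx = 4.0, clear the gradient between calls; a minimal sketch (the zeroing step is an addition for illustration, not part of the example above):
python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x * x

y.backward(retain_graph=True)
print(x.grad.item())   # 4.0

x.grad.zero_()         # clear the accumulated gradient before the next pass
y.backward()
print(x.grad.item())   # 4.0 again, not 8.0
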
🎯 When to Use

You use retain_graph=True when you need to call backward() multiple times on the same computation graph. This happens in advanced training scenarios like:

  • Computing higher-order derivatives (gradients of gradients), as sketched right after this list.
  • Training models with multiple losses that share parts of the graph.
  • Custom optimization loops where you accumulate gradients in steps.
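
For the higher-order case, the usual route is torch.autograd.grad with create_graph=True, which retains the graph and additionally records the backward computation so the gradient itself can be differentiated. A minimal sketch, added here for illustration:
python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

# First derivative; create_graph=True keeps the graph and makes
# the resulting gradient differentiable
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)
print(dy_dx.item())    # 3 * x^2 = 27.0

# Second derivative (gradient of the gradient)
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)
print(d2y_dx2.item())  # 6 * x = 18.0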

Remember, keeping the graph uses more memory, so only use it when necessary.
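
For the multi-loss case, retain_graph=True on the first backward call keeps the shared part of the graph alive so the second loss can still backpropagate through it. A minimal sketch with a made-up two-loss setup (the tensors and losses below are illustrative only):
python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
shared = x * 2                     # subgraph shared by both losses
loss1 = shared.sum()
loss2 = (shared ** 2).sum()

loss1.backward(retain_graph=True)  # keep the shared graph for the next pass
loss2.backward()                   # last pass can free the graph as usual
print(x.grad)                      # gradients from both losses, accumulated

When both losses are available at the same time, calling (loss1 + loss2).backward() once avoids the extra memory; retain_graph earns its keep when the backward passes genuinely have to happen at different points in the training step.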

Key Points

  • retain_graph=False (default) frees the graph after backward to save memory.
  • Set retain_graph=True to keep the graph for multiple backward calls.
  • Useful for higher-order gradients or multiple backward passes.
  • Using retain_graph=True unnecessarily can cause higher memory use.

Key Takeaways

  • retain_graph=True keeps the computation graph after backward for reuse.
  • Use retain_graph when you need multiple backward passes on the same graph.
  • Default retain_graph=False frees memory by deleting the graph after backward.
  • Retaining the graph increases memory usage, so use it only when needed.
  • Common in advanced training like higher-order gradients or multiple losses.