What is retain_graph in PyTorch: Explanation and Usage
retain_graph is a parameter of backward() that keeps the computation graph after gradients are computed. By default, PyTorch frees the graph to save memory, but setting retain_graph=True allows multiple backward passes over the same graph.
How It Works
When you run backward() in PyTorch, it calculates gradients by tracing back through the computation graph that recorded all operations. Normally, after this backward pass, PyTorch frees this graph to save memory, like cleaning up after finishing a task.
However, sometimes you want to do more than one backward pass on the same graph, like reusing a recipe to bake multiple cakes without rewriting it each time. Setting retain_graph=True tells PyTorch to keep the graph alive after the first backward call, so you can call backward() again on the same graph without errors.
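To see the default behavior described above, here is a minimal sketch (the variable names are illustrative): the second backward() call fails with a RuntimeError because the graph was freed after the first.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x * x

y.backward()  # default retain_graph=False: the graph is freed here

try:
    y.backward()  # fails: the saved tensors needed for backward are gone
except RuntimeError as err:
    print(f'Second backward failed: {err}')
```

Passing retain_graph=True to the first backward() call avoids this error.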
Example
This example uses retain_graph=True to run two backward passes on the same graph. Note that PyTorch accumulates gradients, so x.grad is 4.0 after the first call and 8.0 after the second.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x * x  # y = x^2

# First backward pass: keep the graph alive
y.backward(retain_graph=True)
print(f'Gradient after first backward: {x.grad.item()}')  # 4.0

# Second backward pass on the same graph; gradients accumulate
y.backward()
print(f'Gradient after second backward: {x.grad.item()}')  # 8.0
```
When to Use
You use retain_graph=True when you need to call backward() multiple times on the same computation graph. This happens in advanced training scenarios like:
- Computing higher-order derivatives (gradients of gradients).
- Training models with multiple losses that share parts of the graph.
- Custom optimization loops where you accumulate gradients in steps.
Remember, keeping the graph uses more memory, so only use it when necessary.
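As a sketch of the multiple-losses case, the snippet below backpropagates two losses that share one forward computation; the names (w, pred, loss1, loss2) are illustrative, not from any specific model.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)
x = torch.tensor(3.0)

pred = w * x              # shared part of the graph
loss1 = (pred - 1) ** 2
loss2 = (pred - 2) ** 2

# Backward on loss1 keeps the shared graph alive for loss2
loss1.backward(retain_graph=True)
loss2.backward()  # would raise a RuntimeError without retain_graph=True above

print(w.grad.item())  # 18.0: gradients from both losses accumulate (12.0 + 6.0)
```

If this were the last backward pass of the step, the final call could omit retain_graph and let the graph be freed normally.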
Key Points
- retain_graph=False (the default) frees the graph after backward() to save memory.
- Set retain_graph=True to keep the graph for multiple backward calls.
- Useful for higher-order gradients or multiple backward passes.
- Using retain_graph=True unnecessarily increases memory use.