
Detaching from the computation graph in PyTorch

Introduction
Detaching a tensor from the computation graph stops autograd from recording operations on it. This saves memory and avoids unwanted gradient calculations. Common situations where this helps:
When you want to use a tensor's value but do not want to update it during training.
When you want to convert a tensor to a NumPy array: calling .numpy() on a tensor that requires gradients raises an error, so detach it first.
When you want to freeze part of a neural network during training.
When you want to do some calculations for logging or visualization without affecting training.
When you want to speed up inference by avoiding gradient tracking (for a whole forward pass, wrapping it in torch.no_grad() is the more common approach).
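The NumPy and logging cases above can be sketched as follows. This is a minimal illustration; the tensor names `w` and `loss_history` are hypothetical, and the "loss" is just a toy expression.

```python
import torch

# Hypothetical tensor tracked by autograd
w = torch.tensor([0.5, -1.5], requires_grad=True)
loss = (w ** 2).sum()

# Converting directly fails because w requires grad
try:
    w.numpy()
    direct_numpy_failed = False
except RuntimeError:
    direct_numpy_failed = True

# Detaching first gives a tensor NumPy can read (on CPU it shares w's memory)
w_np = w.detach().numpy()

# For logging, store a detached loss so the history holds no reference
# to the computation graph
loss_history = []
loss_history.append(loss.detach())

print(direct_numpy_failed)     # True
print(w_np.tolist())           # [0.5, -1.5]
print(loss_history[0].item())  # 2.5
```

Appending `loss` itself instead of `loss.detach()` would keep each step's graph alive for as long as the list exists.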
Syntax
PyTorch
detached_tensor = tensor.detach()
The detached tensor shares the same data but is not connected to the computation graph.
Operations on the detached tensor will not be tracked for gradients.
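To make "shares the same data" concrete, the sketch below compares storage pointers and shows that an in-place edit through the detached tensor is visible in the original, while `detach().clone()` produces an independent copy. The variable names are illustrative. PyTorch's documentation also warns that such in-place edits can trigger autograd's correctness checks if the original tensor is later needed for a backward pass.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2                    # y is recorded in the graph (it has a grad_fn)
y_det = y.detach()           # new tensor object, same underlying storage
y_copy = y.detach().clone()  # independent copy, also outside the graph

print(y.grad_fn is not None)             # True
print(y_det.grad_fn is None)             # True
print(y_det.data_ptr() == y.data_ptr())  # True: no data was copied

# Storage is shared, so an in-place edit to y_det is visible through y,
# while the cloned copy keeps its original value
y_det[0] = 100.0
print(y[0].item())       # 100.0
print(y_copy[0].item())  # 2.0
```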
Examples
Here, y_detached is a tensor with the same values as y but detached from the graph.
PyTorch
import torch
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x * 2
y_detached = y.detach()
Operations on a detached tensor are not tracked, so the result z has requires_grad=False.
PyTorch
z = y_detached + 3
print(z.requires_grad)  # False
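A detached tensor can still be combined with tracked tensors; gradients then flow only through the tracked path, and the detached value acts like a constant. The sketch below compares the same expression with and without `detach()` (the `x_a`/`x_b` names are only for illustration).

```python
import torch

# Case 1: one factor is detached, so it is treated as a constant
x_a = torch.tensor([1.0, 2.0], requires_grad=True)
y_a = x_a * 2
loss_a = (y_a * y_a.detach()).sum()
loss_a.backward()

# Case 2: the same expression without detach
x_b = torch.tensor([1.0, 2.0], requires_grad=True)
y_b = x_b * 2
loss_b = (y_b * y_b).sum()
loss_b.backward()

print(x_a.grad.tolist())  # [4.0, 8.0]: only the tracked factor contributes
print(x_b.grad.tolist())  # [8.0, 16.0]: both factors contribute
```

In case 1 the gradient is 2 times the detached values [2, 4]; in case 2 it is the derivative of 4x², which is 8x.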
detach() always returns a tensor with requires_grad=False, so it is also safe to call on a tensor that never required gradients.
PyTorch
x = torch.tensor([1.0, 2.0])
y = x.detach()
print(y.requires_grad)  # False
Sample Model
This program shows how detaching stops gradient tracking. The detached tensor has requires_grad=False and no grad_fn, so operations on it are not tracked and calling backward() on it raises a RuntimeError.
PyTorch
import torch

# Create a tensor with gradient tracking
input_tensor = torch.tensor([2.0, 3.0], requires_grad=True)

# Perform an operation tracked by autograd
output_tensor = input_tensor * 4

print(f"Before detach: requires_grad = {output_tensor.requires_grad}")

# Detach the output tensor from the computation graph
output_detached = output_tensor.detach()

print(f"After detach: requires_grad = {output_detached.requires_grad}")

# Calling backward on the detached tensor raises a RuntimeError if uncommented
# output_detached.backward(torch.tensor([1.0, 1.0]))

# Do backward on original output tensor
output_tensor.sum().backward()

print(f"Gradient of input_tensor: {input_tensor.grad}")
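Output:
Before detach: requires_grad = True
After detach: requires_grad = False
Gradient of input_tensor: tensor([4., 4.])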
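The commented-out backward call in the program above can be tried safely by catching the error. This is a small sketch reusing the same tensor names; it checks that backward() fails and that no gradient reaches the input.

```python
import torch

input_tensor = torch.tensor([2.0, 3.0], requires_grad=True)
output_detached = (input_tensor * 4).detach()

try:
    output_detached.sum().backward()
    backward_failed = False
except RuntimeError:
    # The detached tensor has no grad_fn, so autograd has no graph to traverse
    backward_failed = True

print(backward_failed)    # True
print(input_tensor.grad)  # None: no gradient reached the input
```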
Important Notes
Detaching a tensor is a constant-time operation and does not copy data.
The detached tensor shares memory with the original but is not tracked for gradients, so in-place changes to one are visible in the other; use detach().clone() when you need an independent copy.
Common mistake: expecting gradients to flow through a detached tensor. They do not.
Use detach() when you want to stop gradients from flowing back through certain parts of your model.
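As a sketch of stopping gradients at a chosen point in a model, the example below detaches the output of a hypothetical `encoder` so that only a hypothetical `head` receives gradients. The layer sizes and random data are arbitrary assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

encoder = nn.Linear(4, 3)  # hypothetical part we want to keep fixed
head = nn.Linear(3, 1)     # hypothetical part we want to train

inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)

features = encoder(inputs).detach()  # gradient flow stops here
preds = head(features)
loss = nn.functional.mse_loss(preds, targets)
loss.backward()

print(encoder.weight.grad)           # None: nothing flowed back to the encoder
print(head.weight.grad is not None)  # True
```

Detaching the features stops gradients for this forward pass only. To freeze a module for the whole of training, calling requires_grad_(False) on its parameters is the more common approach.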
Summary
detach() creates a tensor that shares data but is not tracked for gradients.
Use detach() to stop gradient calculations and save memory.
Calling backward() on a detached tensor raises an error, and no gradient flows back through it into the original graph.