
How to Detach Tensor in PyTorch: Simple Guide

In PyTorch, you can detach a tensor from the computation graph using tensor.detach(). This creates a new tensor that shares the same data but does not track gradients, so it won't affect backpropagation.

Syntax

The basic syntax to detach a tensor is detached_tensor = tensor.detach(). Here, tensor is the original tensor that requires gradients, and detached_tensor is the new tensor that shares the same data but is disconnected from the gradient tracking.

  • tensor: The original tensor with gradient tracking.
  • detach(): Method that returns a new tensor detached from the computation graph.
  • detached_tensor: The resulting tensor without gradient tracking.
python
detached_tensor = tensor.detach()

Example

This example shows how to create a tensor with gradient tracking, detach it, and verify that the detached tensor does not require gradients.

python
import torch

# Create a tensor with requires_grad=True
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Detach the tensor
x_detached = x.detach()

# Print whether each tensor requires gradients
print(f"Original tensor requires_grad: {x.requires_grad}")
print(f"Detached tensor requires_grad: {x_detached.requires_grad}")
Output
Original tensor requires_grad: True
Detached tensor requires_grad: False

Common Pitfalls

A common mistake is to treat the detached tensor as an independent copy. tensor.detach() returns a tensor that shares the same underlying memory as the original, so in-place changes to the original's data also show up in the detached tensor; only the gradient link is cut, so gradients will never flow back through it. (Note that in-place edits on a leaf tensor that requires grad must be wrapped in torch.no_grad(), or PyTorch raises an error.)
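The "gradients won't flow back" point can be seen directly. A minimal sketch: backward() works through the tracked branch, while the detached branch has no graph connection at all.

```python
import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)

# Tracked branch: gradients flow back to x
y = (x * 2).sum()
y.backward()
print(x.grad)  # tensor([2., 2.])

# Detached branch: same data, but no graph connection
z = (x.detach() * 2).sum()
print(z.requires_grad)  # False -- calling z.backward() would raise an error
```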

Also, using tensor.detach_() modifies the tensor in-place, which can cause unexpected behavior if you still need the original tensor with gradient tracking.

python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Wrong: the detached tensor shares memory with the original
x_detached = x.detach()
with torch.no_grad():  # in-place edits on a leaf that requires grad need no_grad
    x[0] = 10.0
print(x_detached)  # shows the updated data, still no gradients

# Right: clone before detaching to avoid shared data
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
x_detached = x.clone().detach()
with torch.no_grad():
    x[0] = 10.0
print(x_detached)  # original data preserved
Output
tensor([10., 2., 3.])
tensor([1., 2., 3.])
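The in-place variant deserves its own warning. A minimal sketch of the detach_() pitfall: unlike detach(), it changes the original tensor itself, so gradient tracking is gone everywhere that tensor is used.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
print(x.requires_grad)  # True

x.detach_()  # in-place: detaches x itself from the graph
print(x.requires_grad)  # False -- the original no longer tracks gradients
```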

Quick Reference

Method                     Description
tensor.detach()            Returns a new tensor detached from the computation graph.
tensor.detach_()           Detaches the tensor in-place, modifying the original tensor.
tensor.clone().detach()    Copies the tensor and detaches it, avoiding shared-data changes.
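A common everyday use of detach() is converting a graph-connected tensor to NumPy, which PyTorch refuses to do while the tensor requires grad. A small sketch:

```python
import torch

# A non-leaf tensor that requires grad (e.g. a computed loss value)
loss = torch.tensor([0.5], requires_grad=True) * 2

# loss.numpy() would raise: "Can't call numpy() on Tensor that requires grad"
value = loss.detach().numpy()  # detach first, then convert
print(value)  # [1.]
```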

Key Takeaways

  • Use tensor.detach() to get a tensor that shares data but does not track gradients.
  • Detached tensors have requires_grad=False and do not affect backpropagation.
  • Modifying the original tensor after detaching changes the detached tensor's data unless you clone first.
  • Avoid detach_() unless you deliberately want to modify the tensor in-place.
  • Use clone().detach() to detach safely without sharing data changes.