
How to Compute Gradients in PyTorch: Simple Guide

In PyTorch, you compute gradients by calling backward() on a tensor that represents a scalar output. Make sure the input tensors have requires_grad=True to track operations for gradient calculation.
📐 Syntax

To compute gradients in PyTorch, you typically use the backward() method on a tensor. This triggers PyTorch's automatic differentiation engine to calculate gradients for all tensors that have requires_grad=True.

Key parts:

  • tensor.backward(): Computes gradients of the tensor with respect to graph leaves.
  • requires_grad=True: Enables gradient tracking on tensors.
  • tensor.grad: Holds the computed gradient after backward() is called.
python
import torch

x = torch.tensor(2.0, requires_grad=True)  # Track gradients
y = x ** 2  # y = x squared

# Compute gradient of y w.r.t x
y.backward()
print(x.grad)  # Prints gradient dy/dx = 2*x = 4.0
Output
tensor(4.)
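To see what "graph leaves" means in practice, here is a small sketch with two tracked inputs instead of one. A single backward() call populates .grad on every leaf tensor that participated in the computation:

```python
import torch

# Two leaf tensors, both tracked
a = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(4.0, requires_grad=True)

# z = a*b + b^2, a scalar built from both leaves
z = a * b + b ** 2

# One backward() call fills .grad on every leaf
z.backward()

print(a.grad)  # dz/da = b = tensor(4.)
print(b.grad)  # dz/db = a + 2*b = tensor(11.)
```

Intermediate tensors like z itself do not get a .grad by default; only leaves do.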
💻 Example

This example shows how to compute the gradient of a simple function y = x^2 + 3x + 1 at x=2. It demonstrates setting requires_grad=True, performing operations, calling backward(), and accessing the gradient.

python
import torch

# Create tensor with gradient tracking
x = torch.tensor(2.0, requires_grad=True)

# Define function y = x^2 + 3x + 1
y = x**2 + 3*x + 1

# Compute gradients
y.backward()

# Print gradient dy/dx at x=2
print(f"Gradient at x=2: {x.grad.item()}")
Output
Gradient at x=2: 7.0
⚠️ Common Pitfalls

Common mistakes when computing gradients in PyTorch include:

  • Not setting requires_grad=True on input tensors, so no gradients are computed.
  • Calling backward() on non-scalar tensors without specifying a gradient argument.
  • Reusing tensors without zeroing gradients, causing accumulation.

Always zero gradients before new backward passes if reusing tensors.

python
import torch

# Wrong: requires_grad not set
x = torch.tensor(2.0)
y = x**2
try:
    y.backward()
except RuntimeError as e:
    print(f"Error: {e}")

# Right: requires_grad=True
x = torch.tensor(2.0, requires_grad=True)
y = x**2

y.backward()
print(f"Gradient: {x.grad}")
Output
Error: element 0 of tensors does not require grad and does not have a grad_fn
Gradient: tensor(4.)
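The accumulation pitfall can be seen directly. This sketch reuses the same leaf tensor across backward passes: gradients add up until they are cleared in place with grad.zero_():

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

# First backward pass: dy/dx = 2*x = 4
y = x ** 2
y.backward()
print(x.grad)  # tensor(4.)

# Second backward pass WITHOUT zeroing: gradients accumulate
y = x ** 2
y.backward()
print(x.grad)  # tensor(8.) -- 4 + 4, not a fresh 4

# Zero the gradient in place before the next pass
x.grad.zero_()
y = x ** 2
y.backward()
print(x.grad)  # tensor(4.) again
```

In a training loop the same clearing step is usually done once per iteration via optimizer.zero_grad().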
📊 Quick Reference

| Concept | Description |
| --- | --- |
| requires_grad | Set to True to track operations for gradients |
| backward() | Computes gradients of a scalar output tensor |
| tensor.grad | Holds the gradient after backward() |
| zero_grad() | Clears old gradients before a new backward pass (method on optimizers/modules; on a tensor, use grad.zero_()) |
| Non-scalar backward | Requires a gradient argument to backward() |
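The last row of the table can be sketched as follows: calling backward() on a vector output requires passing a gradient tensor (the vector the Jacobian is multiplied with), commonly torch.ones_like(y) to sum the per-element gradients:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2  # non-scalar output: y.backward() alone would raise a RuntimeError

# Pass a vector of ones to get the elementwise gradients dy_i/dx_i = 2*x_i
y.backward(gradient=torch.ones_like(y))
print(x.grad)  # tensor([2., 4., 6.])
```

An equivalent and often simpler pattern is to reduce to a scalar first, e.g. y.sum().backward().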

Key Takeaways

  • Set requires_grad=True on tensors to enable gradient tracking.
  • Call backward() on a scalar tensor to compute gradients.
  • Access gradients via the .grad attribute of tensors.
  • Zero gradients before new backward passes to avoid accumulation.
  • backward() on non-scalar tensors needs a gradient argument.