
How to Use gather in PyTorch: Syntax and Examples

In PyTorch, torch.gather collects values along an axis specified by dim using indices from another tensor. It returns a new tensor where each element is selected from the input tensor according to the indices provided.

Syntax

The torch.gather function takes three main arguments:

  • input: The source tensor to gather values from.
  • dim: The dimension along which to index.
  • index: A tensor of indices specifying which elements to gather.

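The output tensor has the same shape as index. The index tensor must have the same number of dimensions as input, its values must lie between 0 and input.size(dim) - 1, and in every dimension other than dim its size must not exceed that of input. The PyTorch documentation describes index as a LongTensor (int64).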

python
torch.gather(input, dim, index)
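
To make the indexing rule concrete, here is a minimal sketch that re-implements the 2D case with plain loops and compares it against torch.gather. The helper name gather_2d_reference is hypothetical and exists only for illustration; it is not part of PyTorch.

```python
import torch

def gather_2d_reference(src, dim, index):
    """Hypothetical loop-based reference for 2D tensors (illustration only)."""
    out = torch.empty_like(index, dtype=src.dtype)
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            k = int(index[i, j])
            # dim=0: the index replaces the row position.
            # dim=1: the index replaces the column position.
            out[i, j] = src[k, j] if dim == 0 else src[i, k]
    return out

src = torch.tensor([[1, 2], [3, 4]])
idx = torch.tensor([[0, 0], [1, 0]])

# The loop version should agree with torch.gather for both dimensions.
for dim in (0, 1):
    assert torch.equal(gather_2d_reference(src, dim, idx),
                       torch.gather(src, dim, idx))
```

In words: for dim=1, out[i][j] = input[i][index[i][j]]; for dim=0, out[i][j] = input[index[i][j]][j].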

Example

This example shows how to use torch.gather to select elements from a 2D tensor along dimension 1 using an index tensor.

python
import torch

# Create a 2D tensor
input_tensor = torch.tensor([[10, 20, 30],
                             [40, 50, 60]])

# Indices to gather along dim=1
index_tensor = torch.tensor([[2, 1, 0],
                             [0, 2, 1]])

# Gather elements
output_tensor = torch.gather(input_tensor, dim=1, index=index_tensor)

print(output_tensor)
Output
tensor([[30, 20, 10],
        [40, 60, 50]])
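
For contrast, the sketch below gathers from the same input along dim=0, where each index value picks a row and the column stays fixed. The variable names row_index and picked are illustrative choices, not taken from the original example.

```python
import torch

input_tensor = torch.tensor([[10, 20, 30],
                             [40, 50, 60]])

# For dim=0, each index value selects a row; the column position is kept.
# The index has shape (1, 3), so the output has a single row.
row_index = torch.tensor([[1, 0, 1]])

picked = torch.gather(input_tensor, dim=0, index=row_index)
print(picked)        # tensor([[40, 20, 60]])
print(picked.shape)  # torch.Size([1, 3])
```

Note that the index here has fewer rows than the input, which is allowed because the output simply takes the shape of the index.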

Common Pitfalls

Common mistakes when using torch.gather include:

  • Using an index tensor whose values are out of range for the specified dimension, which causes a runtime error.
  • Passing an index tensor whose number of dimensions differs from input, or whose shape is not the shape you want for the output.
  • Misunderstanding the dim parameter, which controls the axis along which the index values are applied.

Always ensure index values are valid indices for input along dim.

python
import torch

input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])

# Wrong: index values out of range (3 is invalid for dim=1 with size 3)
index_wrong = torch.tensor([[0, 3, 1], [1, 0, 2]])

try:
    output_wrong = torch.gather(input_tensor, dim=1, index=index_wrong)
except (RuntimeError, IndexError) as e:
    # On CPU, gather typically reports this as a RuntimeError, not an IndexError
    print(f"Error: {e}")

# Right: valid indices
index_right = torch.tensor([[0, 2, 1], [1, 0, 2]])
output_right = torch.gather(input_tensor, dim=1, index=index_right)
print(output_right)
Output
Error: index 3 is out of bounds for dimension 1 with size 3
tensor([[1, 3, 2],
        [5, 4, 6]])
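
One way to catch these mistakes early is to validate the index before calling torch.gather. The following is a minimal sketch under stated assumptions: check_gather_index is a hypothetical helper, not a PyTorch function, and it conservatively requires an int64 index because the documentation describes index as a LongTensor.

```python
import torch

def check_gather_index(src, dim, index):
    """Hypothetical pre-check for torch.gather arguments (illustration only)."""
    if index.dtype != torch.int64:
        raise ValueError(f"index must be int64, got {index.dtype}")
    if index.dim() != src.dim():
        raise ValueError(
            f"index has {index.dim()} dims but input has {src.dim()}"
        )
    dim = dim if dim >= 0 else dim + src.dim()  # normalize a negative dim
    for d in range(src.dim()):
        # Outside the gather dimension, index may not be larger than input.
        if d != dim and index.size(d) > src.size(d):
            raise ValueError(
                f"index.size({d})={index.size(d)} exceeds "
                f"input.size({d})={src.size(d)}"
            )
    if index.numel() > 0:
        lo, hi = int(index.min()), int(index.max())
        if lo < 0 or hi >= src.size(dim):
            raise ValueError(
                f"index values must be in [0, {src.size(dim) - 1}], "
                f"got [{lo}, {hi}]"
            )

input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
index_bad = torch.tensor([[0, 3, 1], [1, 0, 2]])
index_ok = torch.tensor([[0, 2, 1], [1, 0, 2]])

try:
    check_gather_index(input_tensor, 1, index_bad)
except ValueError as e:
    print(f"Rejected: {e}")

check_gather_index(input_tensor, 1, index_ok)  # passes silently
print(torch.gather(input_tensor, 1, index_ok))
```

Reading index.min() and index.max() back into Python forces a synchronization when the tensors live on a GPU, so a check like this is probably best kept for debugging rather than hot training loops.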

Quick Reference

Tips for using torch.gather:

  • Use dim to specify the axis along which you want to pick elements.
  • The index tensor shape must match the output shape you want.
  • Indices in index must be within the range of input.size(dim).
  • Useful for selecting elements based on dynamic indices, like in attention mechanisms or masking.
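
As an illustration of selecting elements with dynamic indices, the sketch below picks one value per row, a pattern that shows up when choosing the score of a particular class or action for each sample. The tensors scores and chosen are made-up example data.

```python
import torch

scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.8, 0.1, 0.1]])
chosen = torch.tensor([1, 0])  # one column index per row

# gather needs index to have as many dimensions as scores,
# so reshape (2,) -> (2, 1), gather, then drop the extra dimension.
picked = scores.gather(1, chosen.unsqueeze(1)).squeeze(1)
print(picked)  # tensor([0.7000, 0.8000])

# The same selection written with advanced indexing, for comparison.
same = scores[torch.arange(scores.size(0)), chosen]
assert torch.equal(picked, same)
```

Both forms give the same result here; which one reads better is largely a matter of taste.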

Key Takeaways

torch.gather selects elements from a tensor along a specified dimension using an index tensor.
The index tensor must have valid indices and match the output shape.
The dim parameter controls the axis along which elements are gathered.
Common errors come from out-of-range indices or shape mismatches.
torch.gather is useful for advanced indexing tasks like attention or masking.