How to Use gather in PyTorch: Syntax and Examples
In PyTorch, torch.gather collects values along the axis specified by dim, using indices from another tensor. It returns a new tensor in which each element is selected from the input tensor according to the corresponding index.
Syntax
The torch.gather function has the following syntax:
python
torch.gather(input, dim, index)
- input: The source tensor to gather values from.
- dim: The dimension along which to index.
- index: A tensor of integer indices (documented as a LongTensor) specifying which elements to gather.
The output tensor has the same shape as index. The index tensor must have the same number of dimensions as input, and its values must be valid indices along dim, that is, in the range 0 to input.size(dim) - 1.
Example
This example shows how to use torch.gather to select elements from a 2D tensor along dimension 1 using an index tensor.
python
import torch

# Create a 2D tensor
input_tensor = torch.tensor([[10, 20, 30], [40, 50, 60]])

# Indices to gather along dim=1
index_tensor = torch.tensor([[2, 1, 0], [0, 2, 1]])

# Gather elements
output_tensor = torch.gather(input_tensor, dim=1, index=index_tensor)
print(output_tensor)
Output
tensor([[30, 20, 10],
        [40, 60, 50]])
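To make the role of dim concrete, the sketch below reproduces the result above with an explicit loop, using the rule out[i][j] = input[i][index[i][j]] for dim=1. It then gathers along dim=0, where the rule becomes out[i][j] = input[index[i][j]][j]. The names manual, index_dim0, and out_dim0 are illustrative choices for this sketch, not part of the PyTorch API.

```python
import torch

input_tensor = torch.tensor([[10, 20, 30], [40, 50, 60]])
index_tensor = torch.tensor([[2, 1, 0], [0, 2, 1]])

# dim=1: out[i][j] = input[i][index[i][j]]
gathered = torch.gather(input_tensor, dim=1, index=index_tensor)

# The same result computed element by element
manual = torch.empty_like(index_tensor)
for i in range(index_tensor.size(0)):
    for j in range(index_tensor.size(1)):
        manual[i, j] = input_tensor[i, index_tensor[i, j]]

print(torch.equal(gathered, manual))  # True

# dim=0: out[i][j] = input[index[i][j]][j]
# Here the index values select rows, so they must be 0 or 1.
index_dim0 = torch.tensor([[1, 0, 1], [0, 0, 1]])
out_dim0 = torch.gather(input_tensor, dim=0, index=index_dim0)
print(out_dim0.tolist())  # [[40, 20, 60], [10, 20, 60]]
```

With dim=1 each index picks a column within its own row; with dim=0 each index picks a row within its own column.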
Common Pitfalls
Common mistakes when using torch.gather include:
- Using an index tensor with values out of range for the specified dimension, causing runtime errors.
- Mismatch in shapes between index and the desired output tensor.
- Confusing the dim parameter, which controls the axis along which indices are applied.
Always ensure index values are valid indices for input along dim.
python
import torch

input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])

# Wrong: index values out of range (3 is invalid for dim=1 with size 3)
index_wrong = torch.tensor([[0, 3, 1], [1, 0, 2]])
try:
    output_wrong = torch.gather(input_tensor, dim=1, index=index_wrong)
except (RuntimeError, IndexError) as e:
    # On CPU, PyTorch typically reports this as a RuntimeError;
    # IndexError is also caught in case the exception type differs by version.
    print(f"Error: {e}")

# Right: valid indices
index_right = torch.tensor([[0, 2, 1], [1, 0, 2]])
output_right = torch.gather(input_tensor, dim=1, index=index_right)
print(output_right)
Output
Error: index 3 is out of bounds for dimension 1 with size 3
tensor([[1, 3, 2],
        [5, 4, 6]])
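One way to catch these problems early is to validate the index tensor before calling torch.gather. The helper below, check_gather_index, is a hypothetical name introduced for illustration and is not part of PyTorch. It is a minimal sketch that assumes input has at least one dimension and checks the number of dimensions, the sizes of the other dimensions, and the value range along dim.

```python
import torch

def check_gather_index(input_tensor, dim, index):
    """Hypothetical helper (not a PyTorch API): raise ValueError if index
    cannot be used with torch.gather(input_tensor, dim, index)."""
    if index.dim() != input_tensor.dim():
        raise ValueError("index and input must have the same number of dimensions")
    dim = dim % input_tensor.dim()  # normalize a negative dim
    for d in range(input_tensor.dim()):
        # Along every dimension other than dim, index may not be larger than input
        if d != dim and index.size(d) > input_tensor.size(d):
            raise ValueError(f"index.size({d}) exceeds input.size({d})")
    size = input_tensor.size(dim)
    if index.numel() > 0 and (index.min().item() < 0 or index.max().item() >= size):
        raise ValueError(f"index values must be in [0, {size - 1}] for dim={dim}")

input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])

# Out-of-range value 3 is rejected before gather is called
bad_index = torch.tensor([[0, 3, 1], [1, 0, 2]])
try:
    check_gather_index(input_tensor, 1, bad_index)
    rejected = False
except ValueError as e:
    rejected = True
    print(f"Rejected: {e}")

# A valid index passes the check and can be gathered safely
good_index = torch.tensor([[0, 2, 1], [1, 0, 2]])
check_gather_index(input_tensor, 1, good_index)
result = torch.gather(input_tensor, 1, good_index)
print(result.tolist())  # [[1, 3, 2], [5, 4, 6]]
```

Checking on the host like this forces a synchronization when the tensors live on a GPU, so it is best suited to debugging rather than hot loops.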
Quick Reference
Tips for using torch.gather:
- Use dim to specify the axis along which you want to pick elements.
- The index tensor shape must match the output shape you want.
- Indices in index must be within the range of input.size(dim).
- Useful for selecting elements based on dynamic indices, like in attention mechanisms or masking.
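As an illustration of selecting with dynamic indices, one common pattern is to pick a single value per row, for example the score of a chosen column for each sample. The tensors scores and chosen below are made-up data for this sketch.

```python
import torch

# Made-up example data: 3 rows of scores and one chosen column per row
scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2],
                       [0.2, 0.2, 0.6]])
chosen = torch.tensor([1, 0, 2])

# index must have the same number of dimensions as scores, so add a
# trailing dimension, gather along dim=1, then remove it again.
picked = torch.gather(scores, dim=1, index=chosen.unsqueeze(1)).squeeze(1)

# Cross-check against integer-array indexing, which selects the same elements
expected = scores[torch.arange(scores.size(0)), chosen]
print(torch.equal(picked, expected))  # True
```

The unsqueeze and squeeze calls are needed only because chosen is 1D; an index that already has shape (3, 1) could be passed directly.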
Key Takeaways
torch.gather selects elements from a tensor along a specified dimension using an index tensor.
The index tensor must have valid indices and match the output shape.
The dim parameter controls the axis along which elements are gathered.
Common errors come from out-of-range indices or shape mismatches.
torch.gather is useful for advanced indexing tasks like attention or masking.