
How to Index Tensor in PyTorch: Simple Syntax and Examples

In PyTorch, you index a tensor with square brackets [] containing integers, slices, or boolean masks. For example, on a 2D tensor, tensor[0] returns the first row, and tensor[:, 1] returns the second column (index 1) across all rows. The same syntax lets you both extract and modify parts of a tensor.
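To show the modifying side, the sketch below puts the same indexing expressions on the left of an assignment, which writes into the tensor in place. The tensor values are arbitrary and chosen only for illustration.

```python
import torch

t = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# Assign to a single element (row 0, column 2)
t[0, 2] = 30

# Assign one value to a whole column; it is broadcast to every row
t[:, 0] = 0

# Assign through a boolean mask: every element greater than 5 becomes -1
t[t > 5] = -1

print(t.tolist())  # [[0, 2, -1], [0, 5, -1], [0, -1, -1]]
```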

Syntax

You index a PyTorch tensor using square brackets []. Inside the brackets, you can use:

  • Integers: select one position along a dimension, such as a row or, with one integer per dimension, a single element.
  • Slices: start:stop:step selects a range; stop is exclusive.
  • Ellipsis (...): stands in for all remaining dimensions.
  • Boolean masks: a tensor of True/False values selects the elements where the mask is True.
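The examples that follow don't use a step slice or the ellipsis, so here is a minimal sketch of both. It uses an arbitrary 3D tensor built with torch.arange.

```python
import torch

# A 3D tensor of shape (2, 3, 4) holding the values 0..23
x = torch.arange(24).reshape(2, 3, 4)

# Slice with a step: every second element of x[0, 0] (positions 0 and 2)
every_other = x[0, 0, ::2]

# The ellipsis fills in all leading dimensions, so x[..., 1] equals x[:, :, 1]
last_dim_1 = x[..., 1]

print(every_other.tolist())                 # [0, 2]
print(last_dim_1.shape)                     # torch.Size([2, 3])
print(torch.equal(last_dim_1, x[:, :, 1]))  # True
```

Unlike NumPy, PyTorch slicing does not accept a negative step. torch.flip is the usual way to reverse a dimension.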

Example syntax:

tensor[index]  # single index
tensor[start:stop]  # slice
tensor[mask]  # boolean mask
python
import torch
tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(tensor[0])        # first row
print(tensor[:, 1])     # second column
print(tensor[1, 2])     # element at row 1, col 2
mask = tensor > 5
print(tensor[mask])     # elements greater than 5
Output
tensor([1, 2, 3])
tensor([2, 5, 8])
tensor(6)
tensor([6, 7, 8, 9])
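Two details of the example above deserve a closer look. First, indexing every dimension with an integer, as in tensor[1, 2], produces a 0-dimensional tensor rather than a plain Python number; .item() converts it. Second, you can combine several conditions into one mask with & (and), | (or), and ~ (not). Here is a minimal sketch on the same 3x3 tensor:

```python
import torch

tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Integer indices on every dimension give a 0-dimensional tensor
elem = tensor[1, 2]
print(elem.dim())   # 0
print(elem.item())  # 6, as a plain Python int

# & binds tighter than comparisons, so each condition needs parentheses
mask = (tensor > 2) & (tensor % 2 == 0)
print(tensor[mask].tolist())  # [4, 6, 8]
```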

Example

This example shows how to create a tensor and use different indexing methods to access elements, rows, columns, and filtered values.

python
import torch

tensor = torch.tensor([[10, 20, 30],
                       [40, 50, 60],
                       [70, 80, 90]])

# Access first row
first_row = tensor[0]

# Access second column
second_col = tensor[:, 1]

# Access element at row 2, column 0
elem = tensor[2, 0]

# Boolean mask: elements greater than 50
mask = tensor > 50
filtered = tensor[mask]

print('First row:', first_row)
print('Second column:', second_col)
print('Element at (2,0):', elem)
print('Elements > 50:', filtered)
Output
First row: tensor([10, 20, 30])
Second column: tensor([20, 50, 80])
Element at (2,0): tensor(70)
Elements > 50: tensor([60, 70, 80, 90])
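Shape is easy to overlook when reading these results. An integer index removes the dimension it indexes, but a slice of length 1 keeps it. This sketch reuses the tensor from the example:

```python
import torch

tensor = torch.tensor([[10, 20, 30],
                       [40, 50, 60],
                       [70, 80, 90]])

# An integer index drops that dimension...
row_int = tensor[0]         # first row, 1D
col_int = tensor[:, 1]      # second column, 1D

# ...while a length-1 slice keeps it
row_slice = tensor[0:1]     # first row, still 2D
col_slice = tensor[:, 1:2]  # second column, still 2D

print(row_int.shape, row_slice.shape)  # torch.Size([3]) torch.Size([1, 3])
print(col_int.shape, col_slice.shape)  # torch.Size([3]) torch.Size([3, 1])
```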

Common Pitfalls

Common mistakes when indexing tensors include:

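  • Expecting indexing with a list or tensor of indices (advanced indexing) to return a view, as slicing does. It returns a copy instead.
  • Mixing up dimension order: for a 2D tensor, the first index is the row and the second is the column.
  • Indexing with floats or other unsupported types, or with an index outside a dimension's range.
  • Forgetting that basic slicing returns a view, so modifying the slice also changes the original tensor.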
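The view-versus-copy difference is easiest to see side by side. The minimal sketch below, on an arbitrary 2x2 tensor, takes one result by slicing and one by indexing with a list, then edits each.

```python
import torch

tensor = torch.tensor([[1, 2], [3, 4]])

# Basic slicing returns a view that shares memory with `tensor`
view = tensor[:, 0]
# Indexing with a list of indices (advanced indexing) returns a new tensor
copy = tensor[[0, 1], 0]

view[0] = 100  # also changes tensor[0, 0]
copy[1] = -1   # does not change tensor[1, 0]

print(tensor.tolist())  # [[100, 2], [3, 4]]
print(copy.tolist())    # [1, -1]
```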

Example of wrong and right indexing:

python
import torch

tensor = torch.tensor([[1, 2], [3, 4]])

# Wrong: using a float index
# print(tensor[0.0])  # IndexError

# Wrong: column index out of range (a 2x2 tensor only has columns 0 and 1)
# print(tensor[:, 2])  # IndexError

# Right: integer indices, row 0 then column 1
print(tensor[0, 1])  # prints tensor(2)

# Modifying slice affects original
slice_ = tensor[:, 0]
slice_[0] = 100
print(tensor)  # tensor changed
Output
tensor(2)
tensor([[100,   2],
        [  3,   4]])
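If you need an independent copy of a slice, one common approach is to call .clone() on it, so later edits leave the original untouched. A minimal sketch:

```python
import torch

tensor = torch.tensor([[1, 2], [3, 4]])

# .clone() copies the data, so editing `col` does not touch `tensor`
col = tensor[:, 0].clone()
col[0] = 100

print(col.tolist())     # [100, 3]
print(tensor.tolist())  # [[1, 2], [3, 4]]
```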

Quick Reference

| Indexing type | Description | Example |
|---|---|---|
| Integer | Selects a single element or a row | tensor[0], tensor[1, 2] |
| Slice | Selects a range of elements | tensor[0:2], tensor[:, 1:] |
| Boolean mask | Selects elements where the mask is True | tensor[tensor > 5] |
| Ellipsis | Stands in for all remaining dimensions | tensor[..., 1] |

Key Takeaways

Use square brackets [] with integers, slices, or boolean masks to index tensors in PyTorch.
For 2D tensors, the first index selects the row and the second selects the column.
Slicing returns views, so changes to slices affect the original tensor.
Boolean masks let you select elements based on conditions easily.
Avoid using floats or invalid indices to prevent errors.