How to Index Tensor in PyTorch: Simple Syntax and Examples
In PyTorch, you index a tensor using square brackets [] with integers, slices, or boolean masks. For example, tensor[0] accesses the first element along the first dimension (the first row of a 2-D tensor), and tensor[:, 1] selects the second column across all rows. Indexing lets you extract or modify parts of a tensor.
Syntax
You index a PyTorch tensor using square brackets []. Inside the brackets, you can use:
- Integers: to select a specific element or row.
- Slices: like start:stop:step, to select ranges.
- Ellipsis (...): to represent all remaining dimensions.
- Boolean masks: tensors of True/False values to select elements conditionally.
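The examples that follow do not use the step or ellipsis forms from the list above. The sketch below shows both. The 3-D tensor t, built with torch.arange, is a hypothetical example of our own. The negative-step check reflects our understanding that PyTorch accepts only positive steps, unlike NumPy, so treat that part as an assumption to verify against your version.

```python
import torch

# Hypothetical 2 x 3 x 4 tensor holding the values 0..23
t = torch.arange(24).reshape(2, 3, 4)

# Slice with a step: every second element of t[0, 0] (positions 0 and 2)
every_other = t[0, 0, ::2]
print(every_other)  # tensor([0, 2])

# Ellipsis stands in for all leading dimensions, so t[..., 1] equals t[:, :, 1]
last_dim_1 = t[..., 1]
print(last_dim_1.shape)                     # torch.Size([2, 3])
print(torch.equal(last_dim_1, t[:, :, 1]))  # True

# Assumption: PyTorch rejects negative slice steps with a ValueError
try:
    t[0, 0, ::-1]
except ValueError as err:
    print("negative step rejected:", err)

# torch.flip reverses along the given dimensions instead
print(torch.flip(t[0, 0], dims=[0]))  # tensor([3, 2, 1, 0])
```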
Example syntax:
```python
tensor[index]       # single index
tensor[start:stop]  # slice
tensor[mask]        # boolean mask
```
```python
import torch

tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

print(tensor[0])     # first row
print(tensor[:, 1])  # second column
print(tensor[1, 2])  # element at row 1, column 2 (0-based)

mask = tensor > 5
print(tensor[mask])  # elements greater than 5
```
Output
tensor([1, 2, 3])
tensor([2, 5, 8])
tensor(6)
tensor([6, 7, 8, 9])
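The outputs above hint at a useful difference. An integer index removes the dimension it indexes, while a length-1 slice at the same position keeps it. The short sketch below reuses the same tensor to show this. Indexing every dimension with an integer gives a 0-D tensor, and .item() turns that into a plain Python number.

```python
import torch

tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Integer index drops dimension 0: the result is 1-D
print(tensor[0].shape)    # torch.Size([3])

# A length-1 slice keeps dimension 0: the result is still 2-D
print(tensor[0:1].shape)  # torch.Size([1, 3])

# Integers in every dimension give a 0-D tensor; .item() extracts the number
print(tensor[1, 2].dim(), tensor[1, 2].item())  # 0 6
```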
Example
This example shows how to create a tensor and use different indexing methods to access elements, rows, columns, and filtered values.
```python
import torch

tensor = torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]])

# Access first row
first_row = tensor[0]
# Access second column
second_col = tensor[:, 1]
# Access element at row 2, column 0
elem = tensor[2, 0]
# Boolean mask: elements greater than 50
mask = tensor > 50
filtered = tensor[mask]

print('First row:', first_row)
print('Second column:', second_col)
print('Element at (2,0):', elem)
print('Elements > 50:', filtered)
```
Output
First row: tensor([10, 20, 30])
Second column: tensor([20, 50, 80])
Element at (2,0): tensor(70)
Elements > 50: tensor([60, 70, 80, 90])
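Indexing can also appear on the left-hand side of an assignment to modify a tensor in place. The sketch below reuses the values from the example above in a fresh tensor t (the name is ours). It assigns through an integer index, a slice, and a boolean mask.

```python
import torch

t = torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]])

# Assign a single element (row 0, column 0)
t[0, 0] = -1

# Assign a whole column with a slice; the scalar is broadcast to every row
t[:, 2] = 0

# Assign through a boolean mask: cap every value above 50 at 50
t[t > 50] = 50

print(t.tolist())  # [[-1, 20, 0], [40, 50, 0], [50, 50, 0]]
```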
Common Pitfalls
Common mistakes when indexing tensors include:
- Using index tensors of the wrong type for advanced indexing: index tensors must be integer or boolean (a Python list of ints also works).
- Mixing up row and column order (PyTorch uses row, then column).
- Trying to index with floats or unsupported types.
- Forgetting that basic slicing returns a view, so modifying a slice changes the original tensor.
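The sketch below illustrates the advanced-indexing pitfall by showing what PyTorch accepts, assuming a recent release. A Python list of ints or an integer tensor works as an index. A float tensor raises an IndexError until you cast it to an integer type. The tensor t and the index values are illustrative only.

```python
import torch

t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# A Python list of ints selects rows 0 and 2
rows_from_list = t[[0, 2]]
# An int64 (long) index tensor gives the same result
rows_from_tensor = t[torch.tensor([0, 2])]
print(torch.equal(rows_from_list, rows_from_tensor))  # True

# Paired index tensors pick single elements: (0, 2) and (2, 0)
picked = t[torch.tensor([0, 2]), torch.tensor([2, 0])]
print(picked)  # tensor([3, 7])

# Assumption: a float index tensor raises IndexError; cast it with .long() first
float_idx = torch.tensor([0.0, 2.0])
try:
    t[float_idx]
except IndexError as err:
    print("float index rejected:", err)
print(t[float_idx.long()].shape)  # torch.Size([2, 3])
```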
Example of wrong and right indexing:
```python
import torch

tensor = torch.tensor([[1, 2], [3, 4]])

# Wrong: float index
# print(tensor[0.0])  # IndexError
# Wrong: column index out of range (this tensor only has columns 0 and 1)
# print(tensor[:, 2])  # IndexError

# Right: integer indices
print(tensor[0, 1])  # prints tensor(2)

# Modifying a slice affects the original tensor
slice_ = tensor[:, 0]
slice_[0] = 100
print(tensor)  # tensor changed
```
Output
tensor(2)
tensor([[100,   2],
        [  3,   4]])
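The view behaviour shown above applies to basic indexing with integers and slices. Boolean-mask and integer-array indexing return new tensors instead, so writing to their results leaves the original untouched. The variable names in this minimal sketch are illustrative.

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])

# Basic slicing returns a view: writing to it changes t
col_view = t[:, 0]
col_view[1] = 30
print(t.tolist())  # [[1, 2], [30, 4]]

# Boolean-mask indexing returns a copy: writing to it leaves t alone
masked = t[t > 1]
masked[0] = -5
print(t.tolist())       # [[1, 2], [30, 4]]
print(masked.tolist())  # [-5, 30, 4]

# .clone() gives an independent copy of a slice
safe = t[:, 1].clone()
safe[0] = 99
print(t[0, 1].item())  # 2
```

When you need to change a slice without touching the original tensor, calling .clone() first is the usual approach.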
Quick Reference
| Indexing Type | Description | Example |
|---|---|---|
| Integer | Selects a specific element or row | tensor[0], tensor[1,2] |
| Slice | Selects a range of elements | tensor[0:2], tensor[:, 1:] |
| Boolean Mask | Selects elements where mask is True | tensor[tensor > 5] |
| Ellipsis | Represents all remaining dimensions | tensor[..., 1] |
Key Takeaways
Use square brackets [] with integers, slices, or boolean masks to index tensors in PyTorch.
Remember PyTorch indexing order is row first, then column for 2D tensors.
Basic slicing returns views, so changes to slices affect the original tensor; boolean-mask and integer-array indexing return copies.
Boolean masks let you select elements based on conditions easily.
Use integer indices within bounds; float or out-of-range indices raise an IndexError.