How to Slice a Tensor in PyTorch: Syntax and Examples
In PyTorch, you slice a tensor using standard Python slicing syntax inside square brackets, like tensor[start:stop:step]. You can slice along any dimension by specifying a slice for each dimension, separated by commas. For example, tensor[:, 1:4] selects all rows and columns 1 to 3.
Syntax
PyTorch tensor slicing uses Python's standard slice notation inside square brackets. You can specify start, stop, and step for each dimension separated by commas.
- tensor[start:stop:step] slices one dimension.
- tensor[dim1_slice, dim2_slice, ...] slices multiple dimensions.
- Omitting start means from the beginning, omitting stop means until the end, and omitting step means a step of 1.
python
tensor[start:stop:step]
tensor[dim1_slice, dim2_slice, ...]
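The defaults for omitted parts are easiest to see on a small 1D tensor. The following is a minimal sketch; the variable names (t, first_three, and so on) are chosen for illustration only.

```python
import torch

# A 1D tensor holding 0..9, so each value equals its index
t = torch.arange(10)

first_three = t[:3]    # start omitted: begins at index 0
from_seven = t[7:]     # stop omitted: runs to the end
every_other = t[::2]   # start and stop omitted, step of 2
stepped = t[2:8:3]     # indices 2 and 5 (stop index 8 is excluded)

print(first_three)     # tensor([0, 1, 2])
print(from_seven)      # tensor([7, 8, 9])
print(every_other)     # tensor([0, 2, 4, 6, 8])
print(stepped)         # tensor([2, 5])
```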
Example
This example shows how to slice a 2D tensor to get specific rows and columns using PyTorch slicing syntax.
python
import torch

# Create a 3x5 tensor with values from 0 to 14
tensor = torch.arange(15).reshape(3, 5)

# Slice: all rows, columns 1 to 3 (stop index 4 is excluded)
sliced = tensor[:, 1:4]

# sep="\n" puts each tensor on its own line, directly below its label
print("Original tensor:", tensor, sep="\n")
print("Sliced tensor (all rows, cols 1 to 3):", sliced, sep="\n")
Output
Original tensor:
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])
Sliced tensor (all rows, cols 1 to 3):
tensor([[ 1,  2,  3],
        [ 6,  7,  8],
        [11, 12, 13]])
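One detail the example leaves implicit is how the result's shape depends on whether a dimension gets a slice or a single integer. The sketch below reuses the same 3x5 tensor; row_as_1d and row_as_2d are illustrative names.

```python
import torch

tensor = torch.arange(15).reshape(3, 5)

# An integer index selects one row and drops that dimension
row_as_1d = tensor[1, 1:4]

# A length-1 slice selects the same row but keeps the dimension
row_as_2d = tensor[1:2, 1:4]

print(row_as_1d)        # tensor([6, 7, 8])
print(row_as_1d.shape)  # torch.Size([3])
print(row_as_2d.shape)  # torch.Size([1, 3])
```

Using a slice such as 1:2 instead of the integer 1 is a simple way to keep the result 2D when later code expects a matrix.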
Common Pitfalls
Common mistakes when slicing tensors include:
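- Using out-of-range indices: slice bounds past the end are clamped silently and can yield an empty tensor, while a single out-of-range integer index raises an IndexError.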
- Confusing inclusive vs exclusive end index (Python slices exclude the stop index).
- Forgetting to use commas to separate slices for multiple dimensions.
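- Forgetting that basic slices are views that share memory with the original tensor: in-place changes to a slice also change the original. Indexing with a list, an index tensor, or a boolean mask returns a copy instead, and .clone() gives an independent copy of a slice.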
python
import torch

# Wrong: no comma between the two slices, so this line is a SyntaxError
# sliced_wrong = tensor_2d[1:3 2:5]

# Right: separate the slice for each dimension with a comma
tensor_2d = torch.arange(20).reshape(4, 5)
sliced_right = tensor_2d[1:3, 2:5]  # rows 1-2, columns 2-4
print(sliced_right)
Output
tensor([[ 7,  8,  9],
        [12, 13, 14]])
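The out-of-range and view-versus-copy pitfalls can also be checked directly. This is a small sketch rather than an exhaustive test; all variable names are illustrative.

```python
import torch

base = torch.arange(10)

# Slice bounds past the end are clamped rather than rejected
clamped = base[7:100]   # tensor([7, 8, 9])
empty = base[20:30]     # empty tensor, no error

# A single integer index past the end raises IndexError
try:
    base[20]
    index_error_raised = False
except IndexError:
    index_error_raised = True

# Basic slicing returns a view: in-place writes reach the original
view = base[2:5]
view[0] = -1            # base[2] is now -1

# clone() produces an independent copy of the slice
copy = base[2:5].clone()
copy[1] = 99            # base[3] is still 3

print(clamped, empty.numel(), index_error_raised)
print(base)
```

After this runs, base holds -1 at index 2 (written through the view), while index 3 is unchanged because the write went to the cloned copy.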
Quick Reference
| Syntax | Description |
|---|---|
| tensor[start:stop] | Slice from start index up to but not including stop index |
| tensor[start:stop:step] | Slice with step size (e.g., every 2nd element) |
| tensor[:, col_start:col_stop] | Slice all rows and specific columns |
| tensor[row_start:row_stop, :] | Slice specific rows and all columns |
| tensor[dim1_slice, dim2_slice, ...] | Slice multiple dimensions with comma-separated slices |
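Each row of the table can be tried on a small 2D tensor to confirm the resulting shapes. The 4x6 tensor m and the specific bounds below are illustrative choices, not part of the reference itself.

```python
import torch

m = torch.arange(24).reshape(4, 6)

rows_1_2 = m[1:3]          # tensor[start:stop] on the first dimension
every_2nd_row = m[::2]     # tensor[start:stop:step]: rows 0 and 2
cols_2_4 = m[:, 2:5]       # all rows, columns 2 to 4
rows_all_cols = m[1:3, :]  # rows 1 to 2, all columns
block = m[1:3, 2:5]        # rows 1 to 2 and columns 2 to 4

print(rows_1_2.shape)       # torch.Size([2, 6])
print(every_2nd_row.shape)  # torch.Size([2, 6])
print(cols_2_4.shape)       # torch.Size([4, 3])
print(rows_all_cols.shape)  # torch.Size([2, 6])
print(block)                # tensor([[ 8,  9, 10], [14, 15, 16]])
```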
Key Takeaways
Use Python slice syntax inside square brackets to slice PyTorch tensors.
Separate slices for each dimension with commas to slice multi-dimensional tensors.
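Python slices exclude the stop index, so tensor[1:4] covers indices 1, 2, and 3.
Basic slicing returns a view, so in-place changes to a slice also change the original tensor; use .clone() when you need an independent copy.
Out-of-range slice bounds are clamped and can produce empty tensors, while an out-of-range integer index raises an IndexError.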