
Indexing and slicing in PyTorch - ML Experiment: Train & Evaluate

Experiment - Indexing and slicing
Problem: You have a 3D tensor representing a batch of images with shape (batch_size, height, width). You want to extract specific parts of the images using indexing and slicing.
Current Metrics: N/A (this is a data manipulation task, not a model training task).
Issue: You are unsure how to use PyTorch indexing and slicing correctly to extract sub-tensors without errors.
Your Task
Extract the center 2x2 patch from each image in the batch using PyTorch indexing and slicing.
Do not use any loops; use only tensor indexing and slicing.
The input tensor shape is (4, 5, 5) representing 4 images of 5x5 pixels.
Solution
import torch

# Create a batch of 4 images, each 5x5 pixels
images = torch.arange(4*5*5).reshape(4, 5, 5)

# Extract the center 2x2 patch from each image
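# A 5x5 image has no exactly centered 2x2 patch; this takes rows 1 and 2 and columns 1 and 2 (0-based)
# The slice 1:3 stops before index 3, so it selects indices 1 and 2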
center_patch = images[:, 1:3, 1:3]

print('Original images shape:', images.shape)
print('Center patches shape:', center_patch.shape)
print('Center patches tensor:')
print(center_patch)
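Used the slice 1:3 on the height and width dimensions to select rows 1 and 2 and columns 1 and 2 of every image (the stop index is exclusive).
Used ':' on the batch dimension to select all images in the batch.
Built the input with torch.arange and reshape so that it has the shape of a batch of images and every pixel has a unique value, which makes the result easy to check.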
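One detail that often causes shape errors is the difference between a slice and a plain integer index. The short sketch below, which rebuilds the same images tensor as the solution, shows that a slice keeps the dimension while an integer index removes it. The names patch, row and row_kept are illustrative only.

```python
import torch

# Same input as the solution: 4 images of 5x5 pixels with unique values
images = torch.arange(4 * 5 * 5).reshape(4, 5, 5)

# Slice: the stop index is exclusive, so 1:3 selects indices 1 and 2
patch = images[:, 1:3, 1:3]
print(patch.shape)  # torch.Size([4, 2, 2])
print(patch[0])     # values 6, 7 in the first row and 11, 12 in the second

# An integer index removes that dimension from the result
row = images[:, 1, 1:3]
print(row.shape)    # torch.Size([4, 2])

# A length-1 slice selects the same values but keeps the dimension
row_kept = images[:, 1:2, 1:3]
print(row_kept.shape)  # torch.Size([4, 1, 2])
```

If later code expects a 3D tensor, the length-1 slice form is usually the safer choice.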
Results Interpretation

Before slicing, the tensor shape was (4, 5, 5), representing 4 images of size 5x5.

After slicing, the tensor shape is (4, 2, 2), representing the center 2x2 patch of each image.

Indexing and slicing in PyTorch let you extract specific parts of tensors without Python loops. Basic slicing returns a view of the original data rather than a copy, which keeps it cheap and makes it well suited to data preprocessing and model input preparation.
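Part of why slicing is cheap is that basic slicing returns a view that shares memory with the original tensor. The sketch below illustrates the consequence: writing into the slice also changes the source, and clone() is one way to get an independent copy. The names view_patch and copy_patch are illustrative only.

```python
import torch

images = torch.arange(4 * 5 * 5).reshape(4, 5, 5)

# Basic slicing returns a view: no pixel data is copied
view_patch = images[:, 1:3, 1:3]

# clone() makes an independent copy of the same values
copy_patch = images[:, 1:3, 1:3].clone()

# Writing into the view writes through to images[0, 1, 1]
view_patch[0, 0, 0] = -1

print(images[0, 1, 1].item())      # -1
print(copy_patch[0, 0, 0].item())  # 6, the clone is unaffected
```

Clone the patch before modifying it in place whenever the original batch must stay unchanged.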
Bonus Experiment
Try extracting the bottom-right 3x3 patch from each image using indexing and slicing.
💡 Hint
Use negative indices or calculate the start index as height - 3 and width - 3 to slice the last 3 rows and columns.
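The two approaches named in the hint can be sketched as follows. This is one possible solution under the same (4, 5, 5) input as above, not the only valid one; the names patch_neg and patch_calc are illustrative.

```python
import torch

images = torch.arange(4 * 5 * 5).reshape(4, 5, 5)

# Option 1: negative indices count from the end of each dimension
patch_neg = images[:, -3:, -3:]

# Option 2: compute the start index from the tensor shape
_, height, width = images.shape
patch_calc = images[:, height - 3:, width - 3:]

print(patch_neg.shape)                     # torch.Size([4, 3, 3])
print(torch.equal(patch_neg, patch_calc))  # True
print(patch_neg[0])                        # rows 2 to 4, columns 2 to 4 of the first image
```

The negative-index form does not depend on the image size, so it keeps working if the height or width changes.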