
How to Use torchvision.transforms in PyTorch for Image Processing

Use torchvision.transforms to apply common image preprocessing and augmentation steps like resizing, cropping, and normalization. Compose multiple transforms with transforms.Compose and apply them to images or datasets before training models.

Syntax

The main way to use torchvision.transforms is to build a sequence of image transformations with transforms.Compose. Each transform is a callable object: it takes an image and returns a transformed one, such as a resized copy or a tensor.

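Commonly used transforms:

  • transforms.Resize(size): Resize the image. A tuple (h, w) sets the exact output size; a single int resizes the shorter edge to that length and keeps the aspect ratio.
  • transforms.ToTensor(): Convert a PIL Image or NumPy array of shape (H, W, C) to a float tensor of shape (C, H, W). 8-bit pixel values in [0, 255] are scaled to [0.0, 1.0].
  • transforms.Normalize(mean, std): Normalize each channel of a tensor image as (x - mean) / std.
  • transforms.Compose([list_of_transforms]): Chain multiple transforms so they run in list order.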
python
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

Example

This example creates a demo image, applies a composed transform that resizes it, converts it to a tensor, and normalizes it, then prints the tensor shape and value range.

python
from PIL import Image
from torchvision import transforms

# Define transforms
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Create a solid red 256x256 demo image
# (to load a real file instead, use Image.open(path).convert('RGB'))
img = Image.new('RGB', (256, 256), color='red')

# Apply transform
img_t = transform(img)

# Print tensor info
print(f"Tensor shape: {img_t.shape}")
print(f"Tensor min value: {img_t.min().item():.3f}")
print(f"Tensor max value: {img_t.max().item():.3f}")
Output
Tensor shape: torch.Size([3, 128, 128])
Tensor min value: -1.000
Tensor max value: 1.000

Common Pitfalls

Common mistakes when using torchvision.transforms include:

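  • Applying transforms in the wrong order, for example calling Normalize before ToTensor. Normalize only accepts tensors, so it raises a TypeError when given a PIL Image.
  • Mixing up input types. ToTensor expects a PIL Image or NumPy array and raises a TypeError when given a tensor. In recent torchvision versions, transforms such as Resize, CenterCrop, RandomHorizontalFlip, and ColorJitter accept both PIL Images and tensors.
  • Passing a plain Python list of transforms where a single callable is expected (for example, a dataset's transform argument). A list is not callable, so wrap it in transforms.Compose.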

Always check the expected input type for each transform.

python
from torchvision import transforms
from PIL import Image

# Wrong: Normalize before ToTensor (will cause error)
try:
    wrong_transform = transforms.Compose([
        transforms.Normalize(mean=[0.5], std=[0.5]),
        transforms.ToTensor()
    ])
    img = Image.new('L', (28, 28))
    wrong_transform(img)
except Exception as e:
    # Print only the exception type; the message wording differs between torchvision versions
    print(f"Error: {type(e).__name__}")

# Right order: ToTensor first, then Normalize
right_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
img = Image.new('L', (28, 28))
output = right_transform(img)
print(f"Output tensor shape: {output.shape}")
Output
Error: TypeError
Output tensor shape: torch.Size([1, 28, 28])

Quick Reference

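| Transform | Purpose | Input Type | Output Type |
|---|---|---|---|
| transforms.Resize(size) | Resize image | PIL Image or Tensor | Same as input |
| transforms.CenterCrop(size) | Crop center region | PIL Image or Tensor | Same as input |
| transforms.ToTensor() | Convert to float tensor (8-bit values scaled to [0, 1]) | PIL Image or ndarray | Tensor |
| transforms.Normalize(mean, std) | Per-channel (x - mean) / std | Float Tensor | Tensor |
| transforms.RandomHorizontalFlip(p=0.5) | Randomly flip horizontally | PIL Image or Tensor | Same as input |
| transforms.ColorJitter(...) | Randomly change brightness, contrast, saturation, and hue | PIL Image or Tensor | Same as input |
| transforms.Compose(list) | Chain transforms | Input of the first transform | Output of the last transform |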

Key Takeaways

  • Use transforms.Compose to chain multiple image transformations in order.
  • Always convert images to tensors with ToTensor before applying Normalize.
  • Check each transform’s expected input and output types to avoid errors.
  • Use deterministic transforms (Resize, CenterCrop) for preprocessing and random ones (RandomHorizontalFlip, ColorJitter) for training-time data augmentation.
  • Apply transforms before images reach your PyTorch model, typically by passing them as a dataset's transform argument.