How to Use torchvision.transforms in PyTorch for Image Processing
Use torchvision.transforms to apply common image preprocessing and augmentation steps such as resizing, cropping, and normalization. Chain multiple transforms with transforms.Compose, then apply the result to individual images or pass it to a dataset so every sample is transformed before it reaches the model.
Syntax
The main way to use torchvision.transforms is to build a sequence of image transformations with transforms.Compose. Each transform is a callable object that takes an image and returns a transformed result, such as a resized image or a tensor.
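Example parts:
- transforms.Resize(size): Resize the image. A tuple (h, w) sets the exact output size; a single int resizes the shorter edge to that length and keeps the aspect ratio.
- transforms.ToTensor(): Convert a PIL Image or NumPy array (H x W x C) to a float tensor (C x H x W), scaling 8-bit pixel values to [0, 1].
- transforms.Normalize(mean, std): Normalize each channel of a tensor image as (x - mean) / std.
- transforms.Compose([list_of_transforms]): Chain multiple transforms, applied in list order.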
```python
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
```
Example
This example creates a solid red test image, applies a composed transform that resizes it, converts it to a tensor, and normalizes it, then prints the tensor shape and value range. To use a real file instead, load it with Image.open(path).convert('RGB').
```python
from PIL import Image
from torchvision import transforms

# Define transforms
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Create a 256x256 solid red image for the demo
img = Image.new('RGB', (256, 256), color='red')

# Apply transform
img_t = transform(img)

# Print tensor info
print(f"Tensor shape: {img_t.shape}")
print(f"Tensor min value: {img_t.min().item():.3f}")
print(f"Tensor max value: {img_t.max().item():.3f}")
```
Output
Tensor shape: torch.Size([3, 128, 128])
Tensor min value: -1.000
Tensor max value: 1.000
Common Pitfalls
Common mistakes when using torchvision.transforms include:
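- Calling Normalize on a PIL Image instead of a tensor, which raises a TypeError.
- Applying transforms in the wrong order, e.g., putting Normalize before ToTensor in the Compose list (the usual cause of the previous error).
- Passing an input type a transform does not accept, e.g., giving ToTensor something that is already a tensor. (In recent torchvision versions, most geometric and color transforms such as Resize and ColorJitter accept both PIL Images and tensors.)
- Forgetting to use transforms.Compose to chain multiple transforms; a plain Python list is not callable.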
Always check the expected input type for each transform.
```python
from torchvision import transforms
from PIL import Image

# Wrong: Normalize before ToTensor (will cause an error)
try:
    wrong_transform = transforms.Compose([
        transforms.Normalize(mean=[0.5], std=[0.5]),
        transforms.ToTensor(),
    ])
    img = Image.new('L', (28, 28))
    wrong_transform(img)
except Exception as e:
    print(f"Error: {e}")

# Right order
right_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
img = Image.new('L', (28, 28))
output = right_transform(img)
print(f"Output tensor shape: {output.shape}")
```
Output (exact error wording varies between torchvision versions)
Error: img should be Tensor Image. Got <class 'PIL.Image.Image'>
Output tensor shape: torch.Size([1, 28, 28])
Quick Reference
| Transform | Purpose | Input Type | Output Type |
|---|---|---|---|
| transforms.Resize(size) | Resize image | PIL Image or Tensor | Same as input |
| transforms.CenterCrop(size) | Crop center region | PIL Image or Tensor | Same as input |
| transforms.ToTensor() | Convert to float tensor (C x H x W) | PIL Image or ndarray | Tensor |
| transforms.Normalize(mean, std) | Normalize each channel: (x - mean) / std | Tensor | Tensor |
| transforms.RandomHorizontalFlip(p=0.5) | Flip horizontally with probability p | PIL Image or Tensor | Same as input |
| transforms.ColorJitter(brightness, contrast, saturation, hue) | Random brightness, contrast, saturation, and hue changes | PIL Image or Tensor | Same as input |
| transforms.Compose(list) | Chain transforms | Varies | Varies |
Key Takeaways
Use transforms.Compose to chain multiple image transformations in order.
Always convert images to tensors with ToTensor before applying Normalize.
Check each transform’s expected input and output types to avoid errors.
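The same API covers deterministic preprocessing (Resize, ToTensor, Normalize) and random data augmentation (RandomHorizontalFlip, ColorJitter).
Apply transforms before images reach the model, typically by passing the composed transform to your dataset so it runs on each sample as it is loaded.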