How to Use Transforms in PyTorch for Image Preprocessing
In PyTorch, you use transforms from torchvision.transforms to preprocess images by chaining operations like resizing, cropping, and normalization. You create a transform pipeline using transforms.Compose and apply it to images before feeding them into a model.
Syntax
The basic syntax to use transforms in PyTorch involves importing transforms from torchvision, creating a pipeline with transforms.Compose, and applying it to images.
- transforms.Compose([list_of_transforms]): chains multiple transforms, applying them in list order.
- Common transforms include Resize, CenterCrop, ToTensor, and Normalize.
- Apply the composed transform to an image by calling it like a function.
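```python
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(256),        # resize the shorter side to 256 px, keeping aspect ratio
    transforms.CenterCrop(224),    # crop the central 224x224 region
    transforms.ToTensor(),         # PIL image -> float tensor (C, H, W) scaled to [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],   # per-channel ImageNet mean
                         std=[0.229, 0.224, 0.225]),   # per-channel ImageNet std
])
```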
Example
This example creates a demo image, then applies a transform pipeline that resizes, crops, converts to a tensor, and normalizes it for model input.
```python
from PIL import Image
from torchvision import transforms

# Define the transform pipeline
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create a solid red 300x300 image for the demo
# (for a real file, use Image.open("path/to/image.jpg").convert("RGB"))
img = Image.new('RGB', (300, 300), color='red')

# Apply the transform
img_t = transform(img)

# Check the tensor shape, type, and value range
print(f"Tensor shape: {img_t.shape}")
print(f"Tensor type: {type(img_t)}")
print(f"Tensor min/max: {img_t.min().item():.4f}/{img_t.max().item():.4f}")
```
Output
Tensor shape: torch.Size([3, 224, 224])
Tensor type: <class 'torch.Tensor'>
Tensor min/max: -2.0357/2.2489
Common Pitfalls
Common mistakes when using transforms include:
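- Not converting images to tensors with ToTensor() before normalization. Normalize expects a tensor, so applying it to a PIL image raises a TypeError.
- Using mean and std values that don't match the data the model was trained on, which degrades accuracy. Pretrained ImageNet classification models in torchvision typically expect the values used above.
- Passing the wrong number of mean and std values. Normalize needs one value per channel, so grayscale (single-channel) images need one mean and one std value.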
Always apply ToTensor() before Normalize().
```python
from torchvision import transforms

# Wrong order: when called on a PIL image, Normalize receives the image and raises a TypeError
wrong_transform = transforms.Compose([
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.ToTensor(),
])

# Correct order: convert to a tensor first, then normalize
correct_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```
Quick Reference
| Transform | Purpose | Example Usage |
|---|---|---|
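| Resize | Resize the image (an int sets the shorter side, keeping aspect ratio) | transforms.Resize(256) |
| CenterCrop | Crop the central region to the given size | transforms.CenterCrop(224) |
| ToTensor | Convert a PIL image or NumPy array (H, W, C) to a (C, H, W) float tensor, scaling uint8 values to [0, 1] | transforms.ToTensor() |
| Normalize | Standardize each channel: (x - mean) / std | transforms.Normalize(mean, std) |
| RandomHorizontalFlip | Flip the image horizontally with probability p | transforms.RandomHorizontalFlip(p=0.5) |
| RandomRotation | Rotate by a random angle in [-30, 30] degrees | transforms.RandomRotation(30) |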
Key Takeaways
Use torchvision.transforms.Compose to chain multiple image preprocessing steps.
Always convert images to tensors with ToTensor() before applying Normalize().
Common transforms include Resize, CenterCrop, ToTensor, and Normalize.
Apply transforms to images before feeding them into your PyTorch model.
Check transform order carefully to avoid errors or poor model input.