
Data augmentation in PyTorch

Introduction

Data augmentation expands a training set by applying transformations (flips, rotations, color changes) to existing data. This helps models generalize better and reduces overfitting. It is especially useful in situations like these:

When you have only a small number of images to train a model.
When you want your model to recognize objects from different angles or under different lighting.
When you want to reduce overfitting by exposing the model to varied data.
When you want to improve model accuracy without collecting new data.
Syntax
PyTorch
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor()
])

Use transforms.Compose to combine multiple augmentations.

Apply augmentations only on training data, not on validation or test data.

Examples
Flips the image horizontally with a 50% chance.
PyTorch
transforms.RandomHorizontalFlip(p=0.5)
Rotates the image randomly within ±45 degrees.
PyTorch
transforms.RandomRotation(degrees=45)
Randomly changes brightness and contrast to make images look different.
PyTorch
transforms.ColorJitter(brightness=0.2, contrast=0.2)
Crops a random part of the image and resizes it to 224x224 pixels.
PyTorch
transforms.RandomResizedCrop(size=224)
Sample Model

This code loads the CIFAR10 training data and applies random horizontal flip and rotation to each image. It then prints the shape of one batch of images and their labels.

PyTorch
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define data augmentation transforms
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ToTensor()
])

# Load CIFAR10 training dataset with augmentation
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# Get one batch of images and labels
images, labels = next(iter(train_loader))

print(f'Batch image tensor shape: {images.shape}')
print(f'Batch labels: {labels}')
Important Notes

Apply normalization (transforms.Normalize) after ToTensor, and use the same normalization in the training and evaluation pipelines to keep data consistent.

Augmentation increases training time but improves model generalization.

Do not apply random augmentations to validation or test sets; evaluate with deterministic preprocessing only, so results are comparable across runs.

Summary

Data augmentation creates new training data by changing existing data.

It helps models learn better and avoid overfitting.

Use torchvision transforms like RandomHorizontalFlip and RandomRotation in PyTorch.