PyTorch · How-To · Beginner · 4 min read

How to Use CIFAR10 Dataset in PyTorch: Simple Guide

Use torchvision.datasets.CIFAR10 to load the CIFAR10 dataset in PyTorch. Pass train=True for the 50,000-image training split or train=False for the 10,000-image test split, and wrap the dataset in torch.utils.data.DataLoader to iterate over it in batches.
📐

Syntax

The CIFAR10 dataset in PyTorch is loaded with torchvision.datasets.CIFAR10. You specify root (the directory where the data is stored), train (True for the training split, False for the test split), download (fetch the files if they are not already present), and transform (transformations applied to each image).

Use torch.utils.data.DataLoader to create an iterable over the dataset with options like batch size and shuffling.

python
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load training dataset
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)

# Create data loader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
💻

Example

This example loads the CIFAR10 training data, applies normalization, and prints the shape of one batch of images and labels.

python
import torch
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader

# Define transformations to convert images to tensors and normalize
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR10 training dataset
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)

# Create DataLoader for batch processing
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Get one batch of data
images, labels = next(iter(train_loader))

# Print shapes
print(f'Batch image tensor shape: {images.shape}')
print(f'Batch labels tensor shape: {labels.shape}')
Output
Batch image tensor shape: torch.Size([64, 3, 32, 32])
Batch labels tensor shape: torch.Size([64])
⚠️

Common Pitfalls

  • Forgetting to apply transforms.ToTensor() causes errors because raw images are PIL images, not tensors.
  • Not normalizing images can slow down training or reduce accuracy.
  • Setting download=True every time is unnecessary after the first download.
  • Using too large batch sizes may cause out-of-memory errors.
python
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader

# Wrong: No transform, images are PIL images
train_dataset_wrong = CIFAR10(root='./data', train=True, download=False)
train_loader_wrong = DataLoader(train_dataset_wrong, batch_size=64, shuffle=True)

# Right: Apply ToTensor and Normalize
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset_right = CIFAR10(root='./data', train=True, download=False, transform=transform)
train_loader_right = DataLoader(train_dataset_right, batch_size=64, shuffle=True)
📊

Quick Reference

Here is a quick summary of key parameters when using the CIFAR10 dataset in PyTorch:

| Parameter | Description | Example |
| --- | --- | --- |
| root | Folder to store/download data | './data' |
| train | Load training set if True, else test set | True or False |
| download | Download dataset if not present | True or False |
| transform | Transformations applied to images | transforms.Compose([...]) |
| batch_size | Number of samples per batch in DataLoader | 64 |
| shuffle | Shuffle data each epoch in DataLoader | True or False |

Key Takeaways

  • Use torchvision.datasets.CIFAR10 with train and transform parameters to load data.
  • Always apply transforms.ToTensor() and normalization for proper training.
  • Use DataLoader to batch and shuffle data efficiently.
  • Set download=True only once to avoid repeated downloads.
  • Watch batch size to prevent memory issues during training.