
How to Load a Custom Image Dataset in PyTorch Easily

To load a custom image dataset in PyTorch, create a class that inherits from torch.utils.data.Dataset and implement its __len__ and __getitem__ methods, applying any transformations inside __getitem__. Then wrap the dataset in torch.utils.data.DataLoader to load data in batches with optional shuffling.

Syntax

To load a custom image dataset in PyTorch, you define a class that inherits from torch.utils.data.Dataset. You must implement two methods:

  • __len__(self): Returns the total number of images.
  • __getitem__(self, idx): Returns the image and label at index idx.

Then, use torch.utils.data.DataLoader to create an iterable over the dataset with options like batch size and shuffling.

python
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

class CustomImageDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # Sorted image file names; os.listdir order is arbitrary, so sorting keeps indices stable
        self.img_files = sorted(
            f for f in os.listdir(img_dir)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        )

    def __len__(self):
        # Number of image files found in img_dir
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        # Force three channels so grayscale and RGBA files can be batched together
        image = Image.open(img_path).convert('RGB')
        label = 0  # Placeholder: replace with actual label logic if available
        if self.transform:
            image = self.transform(image)
        return image, label

# Usage:
# dataset = CustomImageDataset(img_dir='path/to/images', transform=some_transform)
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
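The label = 0 placeholder above has to be replaced with real label logic before training. One common convention, assumed here and not stated in the article, is one subfolder per class (for example root/cats/ and root/dogs/). The helper below, index_image_folder (a hypothetical name), is a minimal sketch that derives an integer label for each image from the name of its parent folder:

```python
import os

# File extensions treated as images (an assumption; extend as needed)
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

def index_image_folder(root):
    """Return (samples, class_to_idx) for a root/<class_name>/<image> layout.

    samples is a list of (image_path, label) pairs and class_to_idx maps
    each subfolder name to an integer label, assigned in sorted order.
    """
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    class_to_idx = {name: i for i, name in enumerate(classes)}

    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            # Skip anything that does not look like an image file
            if fname.lower().endswith(IMAGE_EXTENSIONS):
                samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return samples, class_to_idx
```

Inside CustomImageDataset, one option is to call this helper in __init__ (self.samples, self.class_to_idx = index_image_folder(img_dir)), return len(self.samples) from __len__, and unpack img_path, label = self.samples[idx] in __getitem__. For this exact folder layout, torchvision.datasets.ImageFolder is a ready-made alternative.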

Example

This example shows how to load images from a folder, apply basic transformations, and iterate over batches using DataLoader.

python
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

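class CustomImageDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # Sorted image file names; os.listdir order is arbitrary, so sorting keeps indices stable
        self.img_files = sorted(
            f for f in os.listdir(img_dir)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        )

    def __len__(self):
        # Number of image files found in img_dir
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        # Force three channels so every tensor in a batch has the same shape
        image = Image.open(img_path).convert('RGB')
        label = 0  # Dummy label: every image gets class 0 in this example
        if self.transform:
            image = self.transform(image)
        return image, label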

# Define transformations
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

# Create dataset and dataloader
img_dir = 'sample_images'
dataset = CustomImageDataset(img_dir=img_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Iterate over one batch
for images, labels in dataloader:
    print(f'Batch image tensor shape: {images.shape}')
    print(f'Batch labels: {labels}')
    break
Output
Batch image tensor shape: torch.Size([2, 3, 64, 64])
Batch labels: tensor([0, 0])
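The example assumes that a folder named sample_images already exists and contains at least two .jpg or .png files. To try it without real data, a small helper can generate placeholder images with Pillow. This is only a sketch: make_sample_images and its parameters are hypothetical names, not part of PyTorch or Pillow.

```python
import os
import random
from PIL import Image

def make_sample_images(folder, count=4, size=(100, 80), seed=0):
    """Write `count` solid-colour PNG files into `folder` and return their paths."""
    os.makedirs(folder, exist_ok=True)
    rng = random.Random(seed)  # seeded so the colours are reproducible
    paths = []
    for i in range(count):
        # Random RGB colour; size is (width, height) in pixels
        color = tuple(rng.randrange(256) for _ in range(3))
        path = os.path.join(folder, f"img_{i:03d}.png")
        Image.new("RGB", size, color).save(path)
        paths.append(path)
    return paths
```

Calling make_sample_images('sample_images') once before running the example creates four PNG files, which is enough for the DataLoader to produce batches of two.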

Common Pitfalls

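  • Not converting images to RGB can cause errors if images are grayscale or have alpha channels: after ToTensor() they have 1 or 4 channels instead of 3, so they cannot be stacked into one batch with RGB images.
  • Forgetting to implement __len__ or __getitem__ will cause your dataset to fail: without __len__, DataLoader raises a TypeError when it asks for the dataset's length, and without __getitem__, indexing raises NotImplementedError.
  • Not applying ToTensor() leaves images as PIL objects, which DataLoader's default collate function cannot batch (it raises a TypeError) and which PyTorch models cannot take as input.
  • An incorrect directory path raises FileNotFoundError when the dataset is created, and a missing or unreadable image file raises an error when that item is loaded.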
python
from torch.utils.data import Dataset
from PIL import Image
import os

# Wrong: Missing __len__ method
class BadDataset(Dataset):
    def __init__(self, img_dir):
        self.img_dir = img_dir
        self.img_labels = os.listdir(img_dir)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        image = Image.open(img_path)
        return image

# Right: Implement both methods
class GoodDataset(Dataset):
    def __init__(self, img_dir):
        self.img_dir = img_dir
        self.img_labels = os.listdir(img_dir)

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        image = Image.open(img_path).convert('RGB')
        return image
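The RGB pitfall from the list above can be checked without any training code. ToTensor() produces one channel per PIL band, so the sketch below simply counts bands before and after convert('RGB'). The images are synthetic and channel_count is a hypothetical helper name.

```python
from PIL import Image

def channel_count(image):
    """Number of bands in a PIL image; ToTensor() yields one channel per band."""
    return len(image.getbands())

gray = Image.new("L", (8, 8))      # grayscale: 1 band
rgba = Image.new("RGBA", (8, 8))   # colour plus alpha: 4 bands
rgb = Image.new("RGB", (8, 8))     # plain colour: 3 bands

counts_raw = [channel_count(im) for im in (gray, rgba, rgb)]
counts_rgb = [channel_count(im.convert("RGB")) for im in (gray, rgba, rgb)]

print(counts_raw)  # [1, 4, 3]
print(counts_rgb)  # [3, 3, 3]
```

With mixed channel counts, the tensors in a batch have different shapes and cannot be stacked. After convert('RGB') every image has three channels.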

Quick Reference

  • Dataset class: Inherit from torch.utils.data.Dataset and implement __len__ and __getitem__.
  • Image loading: Use PIL.Image.open() and convert to RGB.
  • Transformations: Use torchvision.transforms to resize, normalize, and convert images to tensors.
  • DataLoader: Wrap dataset with DataLoader for batching and shuffling.

Key Takeaways

  • Create a custom Dataset by inheriting from torch.utils.data.Dataset and implementing __len__ and __getitem__.
  • Use PIL to load images and convert them to RGB before applying transforms.
  • Apply torchvision transforms like ToTensor to convert images into tensors usable by PyTorch models.
  • Use DataLoader to batch and shuffle your dataset for efficient training.
  • Check file paths and image formats carefully to avoid runtime errors.