
Dataset class (custom datasets) in PyTorch

Introduction

We use custom Dataset classes to organize and load our own data easily for training machine learning models.

When you have images stored in folders and want to load them with labels.
When your data is in a CSV file and you want to feed it to a model.
When you want to apply transformations like resizing or normalization on the fly.
When you want a DataLoader to batch and shuffle your data efficiently during training.
When you want to combine multiple data sources into one dataset.
Syntax
PyTorch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, label

The __init__ method sets up your data and any transforms.

The __len__ method returns how many samples you have.

The __getitem__ method returns one sample and label by index.
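To make the three methods concrete, here is a minimal sketch of how Python calls them. The class is repeated so the snippet runs on its own; the toy tensors and the doubling transform are assumptions chosen for illustration, not part of the original lesson.

```python
import torch
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """Same structure as the class above, repeated so this sketch is self-contained."""

    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, label


# Hypothetical toy data: three samples with two features each.
data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
labels = torch.tensor([0, 1, 0])

# Hypothetical transform: double every feature when a sample is fetched.
dataset = CustomDataset(data, labels, transform=lambda x: x * 2)

num_samples = len(dataset)   # calls __len__
sample, label = dataset[1]   # calls __getitem__(1)
print(num_samples, sample, label)
```

The transform runs only when a sample is fetched, so the tensor stored in `data` is left unchanged.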

Examples
A simple dataset with only data, no labels.
PyTorch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
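As a rough illustration of what an unlabeled dataset looks like downstream, the sketch below passes it through a DataLoader. Each batch is then a single tensor rather than a (data, label) pair. The toy data of five samples with two features each is an assumption.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    """Unlabeled dataset, repeated here so the sketch is self-contained."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# Hypothetical toy data: five samples with two features each.
unlabeled = MyDataset(torch.arange(10, dtype=torch.float32).reshape(5, 2))
loader = DataLoader(unlabeled, batch_size=2, shuffle=False)

# Each batch is one tensor; the last batch is smaller because 5 is not divisible by 2.
batch_shapes = [tuple(batch.shape) for batch in loader]
print(batch_shapes)
```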
Dataset for images with optional transformations.
PyTorch
from torch.utils.data import Dataset
from torchvision.io import read_image

class ImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # read_image returns a uint8 tensor of shape (channels, height, width)
        image = read_image(self.image_paths[idx])
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
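The same pattern can cover the CSV case mentioned in the introduction. The sketch below is one possible layout, not a standard PyTorch class: `CsvDataset`, the column names, and the assumption that the last column holds an integer label are all hypothetical. An in-memory string stands in for a real file.

```python
import csv
import io

import torch
from torch.utils.data import Dataset


class CsvDataset(Dataset):
    """Hypothetical dataset: numeric feature columns followed by one integer label column."""

    def __init__(self, csv_file, transform=None):
        reader = csv.reader(csv_file)
        next(reader)  # skip the header row
        rows = [row for row in reader if row]  # ignore blank lines
        self.features = torch.tensor([[float(v) for v in row[:-1]] for row in rows])
        self.labels = torch.tensor([int(row[-1]) for row in rows])
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = self.features[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, self.labels[idx]


# In-memory stand-in for a file; with a file on disk you would pass open(path, newline="").
csv_text = "height,weight,label\n1.5,2.0,0\n3.0,4.5,1\n5.5,6.0,0\n"
csv_dataset = CsvDataset(io.StringIO(csv_text))

features, label = csv_dataset[1]
print(len(csv_dataset), features, label)
```

Parsing the whole file in `__init__` keeps `__getitem__` cheap, which suits small files. For data that does not fit in memory, reading rows lazily in `__getitem__` would be the more likely choice.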
Sample Program

This program creates a simple dataset with features and labels, then loads it in batches of 2 using DataLoader. It prints each batch's data and labels.

PyTorch
import torch
from torch.utils.data import Dataset, DataLoader

class SimpleDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Sample data
features = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
labels = torch.tensor([0, 1, 0, 1])

# Create dataset
dataset = SimpleDataset(features, labels)

# Create DataLoader to load data in batches
loader = DataLoader(dataset, batch_size=2, shuffle=False)

for batch_idx, (data, target) in enumerate(loader):
    print(f"Batch {batch_idx+1} data:", data)
    print(f"Batch {batch_idx+1} labels:", target)
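Output
Batch 1 data: tensor([[1, 2],
        [3, 4]])
Batch 1 labels: tensor([0, 1])
Batch 2 data: tensor([[5, 6],
        [7, 8]])
Batch 2 labels: tensor([0, 1])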
Important Notes

For labeled data, return the sample and its label as a tuple from __getitem__; an unlabeled dataset can return just the sample.

Apply transforms inside __getitem__ so that each sample is preprocessed at the moment it is loaded.

DataLoader wraps a dataset to load it in batches and, with shuffle=True, in a random order.
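A small sketch of batching with shuffling follows. It uses the built-in `TensorDataset` as a stand-in for a custom Dataset, and the seeded `torch.Generator` is an assumption added only to make the shuffled order repeatable between runs.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: six samples, labeled 0..5 so the order is easy to follow.
features = torch.arange(12, dtype=torch.float32).reshape(6, 2)
labels = torch.arange(6)
dataset = TensorDataset(features, labels)

# Seeded generator so the shuffled order is the same on every run.
generator = torch.Generator().manual_seed(0)
loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=generator)

seen_labels = []
batch_sizes = []
for data, target in loader:
    batch_sizes.append(len(target))
    seen_labels.extend(target.tolist())

# Every sample appears exactly once per pass; only the order changes.
print(batch_sizes, seen_labels)
```

With six samples and a batch size of 4, the final batch holds the two remaining samples. Passing `drop_last=True` would discard it instead.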

Summary

Custom Dataset classes help organize your data for PyTorch models.

Implement __init__, __len__, and __getitem__ methods.

Use DataLoader to load data in batches during training.