
Custom Dataset class in PyTorch

Introduction

A custom Dataset class tells PyTorch how many samples you have and how to load each one. It keeps data loading organized and separate from your model code.

Use a custom Dataset in these situations:
When you have data in files such as images or text and want to load it for training.
When your data needs special processing before it is fed into a model.
When you want to use PyTorch's DataLoader to handle batching and shuffling.
When your data is not in a format covered by PyTorch's built-in datasets.
Syntax
PyTorch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

The __init__ method sets up your data.

The __len__ method returns how many samples you have.

The __getitem__ method gets one sample and its label by index.
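
As a quick check of how these three methods are used, the sketch below instantiates the same CustomDataset class with made-up tensors (the values are purely illustrative) and shows that len(dataset) calls __len__ and dataset[i] calls __getitem__.

```python
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Illustrative values only: three samples with two features each
data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
labels = torch.tensor([0, 1, 0])

dataset = CustomDataset(data, labels)

num_samples = len(dataset)   # calls dataset.__len__()
sample, label = dataset[1]   # calls dataset.__getitem__(1)

print(num_samples)    # 3
print(sample, label)  # tensor([3., 4.]) tensor(1)
```

You rarely call these methods directly in training code; DataLoader calls them for you when it builds batches.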

Examples
A simple dataset with only data, no labels.
PyTorch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # No label here: each item is just one sample
        return self.data[idx]
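
To illustrate what a label-free dataset yields when batched, here is a minimal sketch that wraps MyDataset in a DataLoader. The tensor values and the variable names (unlabeled, batches) are assumptions chosen for the example. Each batch is a single tensor rather than a (data, label) pair.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Illustrative data: three samples with two features each
unlabeled = MyDataset(torch.arange(6, dtype=torch.float32).reshape(3, 2))

# shuffle=False keeps the original order so the result is easy to inspect
loader = DataLoader(unlabeled, batch_size=2, shuffle=False)
batches = [batch for batch in loader]

print(len(batches))      # 2
print(batches[0].shape)  # torch.Size([2, 2])
print(batches[1].shape)  # torch.Size([1, 2])
```

The last batch is smaller because three samples do not divide evenly into batches of two.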
Dataset for images with optional transformations.
PyTorch
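from torch.utils.data import Dataset
from torchvision.io import read_image

class ImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths  # list of image file paths
        self.labels = labels            # one label per path
        self.transform = transform      # optional callable applied to each image

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # read_image returns a uint8 tensor of shape (channels, height, width)
        image = read_image(self.image_paths[idx])
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label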
Sample Program

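This program creates a small dataset of four samples with features and labels, then uses DataLoader to iterate over it in batches of size 2. It prints each batch's data and labels. Because shuffle=True, the order of the samples changes from run to run.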

PyTorch
import torch
from torch.utils.data import Dataset, DataLoader

class SimpleDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Sample data
features = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
labels = torch.tensor([0, 1, 0, 1])

# Create dataset
dataset = SimpleDataset(features, labels)

# Use DataLoader to get batches
loader = DataLoader(dataset, batch_size=2, shuffle=True)

for batch_data, batch_labels in loader:
    print('Batch data:', batch_data)
    print('Batch labels:', batch_labels)
    print('---')
Important Notes

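Return the sample and label from __getitem__ as tensors (or as NumPy arrays or Python numbers, which DataLoader's default collate function converts to tensors) so that batches can be fed directly to a PyTorch model.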
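
One way to follow the note about tensors, sketched under the assumption that the raw data arrives as plain Python lists, is to convert each item inside __getitem__. The class name ListBackedDataset and the sample values are hypothetical.

```python
import torch
from torch.utils.data import Dataset

class ListBackedDataset(Dataset):
    """Hypothetical dataset that stores plain Python lists."""

    def __init__(self, rows, targets):
        self.rows = rows        # list of feature lists
        self.targets = targets  # list of integer class labels

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        # Convert on access so every item comes back as a tensor
        x = torch.tensor(self.rows[idx], dtype=torch.float32)
        y = torch.tensor(self.targets[idx], dtype=torch.long)
        return x, y

# Illustrative values only
dataset = ListBackedDataset([[1, 2], [3, 4]], [0, 1])
x, y = dataset[1]

print(x, x.dtype)  # tensor([3., 4.]) torch.float32
print(y, y.dtype)  # tensor(1) torch.int64
```

Setting the dtype explicitly is a design choice here: float32 features and int64 (long) class labels are what most loss functions for classification expect.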

You can add data transformations inside __getitem__ if needed.
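
A minimal sketch of a transformation applied inside __getitem__ might look like the following. TransformedDataset and scale_to_unit are hypothetical names introduced for this example, and the pixel-like values are made up.

```python
import torch
from torch.utils.data import Dataset

class TransformedDataset(Dataset):
    """Hypothetical dataset that applies an optional transform per sample."""

    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform is not None:
            # The transform runs each time a sample is requested
            sample = self.transform(sample)
        return sample, self.labels[idx]

def scale_to_unit(sample):
    # Hypothetical transform: map 0..255 values into the range 0..1
    return sample.float() / 255.0

raw = torch.tensor([[0, 255], [51, 102]])
labels = torch.tensor([0, 1])

dataset = TransformedDataset(raw, labels, transform=scale_to_unit)
sample, label = dataset[0]

print(sample)  # tensor([0., 1.])
```

Because the transform is applied on access, the stored data stays unchanged, and a different transform can be passed for training and for evaluation.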

Using DataLoader with your custom dataset helps with batching and shuffling automatically.
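
The batching and shuffling behaviour can be checked with a small sketch. RangeDataset is a hypothetical dataset made up for this example; batch_size, shuffle, drop_last, and generator are regular DataLoader arguments.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class RangeDataset(Dataset):
    """Hypothetical dataset: sample i is the number i, its label is i % 2."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor(float(idx)), torch.tensor(idx % 2)

dataset = RangeDataset(5)

# Five samples in batches of two: the last batch holds one sample
keep_last = DataLoader(dataset, batch_size=2, shuffle=False)
# drop_last=True discards that incomplete final batch
drop_last = DataLoader(dataset, batch_size=2, shuffle=False, drop_last=True)

print(len(keep_last))  # 3
print(len(drop_last))  # 2

# Shuffling changes the order, but every sample still appears once per epoch
generator = torch.Generator().manual_seed(0)
shuffled = DataLoader(dataset, batch_size=2, shuffle=True, generator=generator)
seen = sorted(x.item() for batch, _ in shuffled for x in batch)

print(seen)  # [0.0, 1.0, 2.0, 3.0, 4.0]
```

Passing a seeded generator is optional; it only makes the shuffle order repeatable between runs.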

Summary

A Custom Dataset class organizes your data for PyTorch models.

Implement __init__, __len__, and __getitem__ methods.

Use DataLoader with your dataset to handle batches and shuffling easily.