We use a Custom Dataset class to organize and prepare data for machine learning models in PyTorch. It helps load data easily and cleanly.
Custom Dataset class in PyTorch
Introduction
A custom Dataset class is useful in these situations:
When you have data in files, such as images or text, and want to load it for training.
When your data needs special processing before it is fed into a model.
When you want to use PyTorch's DataLoader to handle batching and shuffling.
When your data is not in a format that PyTorch's built-in datasets support.
Syntax
PyTorch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label
The __init__ method sets up your data.
The __len__ method returns how many samples you have.
The __getitem__ method gets one sample and its label by index.
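As a minimal sketch of how these three methods are used, the snippet below repeats the CustomDataset class from the Syntax section and calls it through Python's built-in len() and indexing. The tensor values are made up for illustration.

```python
import torch
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


# Illustrative values only; any indexable data of equal length would work.
data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
labels = torch.tensor([0, 1, 0])
dataset = CustomDataset(data, labels)

num_samples = len(dataset)   # Python calls dataset.__len__()
sample, label = dataset[1]   # Python calls dataset.__getitem__(1)

print(num_samples)    # 3
print(sample, label)  # tensor([3., 4.]) tensor(1)
```

DataLoader relies on the same two calls: it asks for the length to know how many indices exist, then fetches samples by index.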
Examples
A simple dataset with only data, no labels.
PyTorch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # No label: each item is just the sample itself.
        return self.data[idx]
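A label-free dataset can still be passed to DataLoader. The sketch below assumes five made-up samples with two features each. Every batch is then a single tensor rather than a (data, labels) pair.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# Five illustrative samples with two features each.
unlabeled = MyDataset(torch.arange(10, dtype=torch.float32).reshape(5, 2))
loader = DataLoader(unlabeled, batch_size=2)  # shuffle defaults to False

# Each batch is one tensor; the last batch is smaller because 5 is not
# divisible by 2 and drop_last defaults to False.
batch_shapes = [tuple(batch.shape) for batch in loader]
print(batch_shapes)  # [(2, 2), (2, 2), (1, 2)]
```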
Dataset for images with optional transformations.
PyTorch
from torch.utils.data import Dataset
from torchvision.io import read_image

class ImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Images are read lazily, one file per requested index.
        image = read_image(self.image_paths[idx])
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label
Sample Program
This program creates a simple dataset with features and labels, then uses DataLoader to get batches of size 2. It prints each batch's data and labels.
PyTorch
import torch
from torch.utils.data import Dataset, DataLoader

class SimpleDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Sample data
features = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
labels = torch.tensor([0, 1, 0, 1])

# Create dataset
dataset = SimpleDataset(features, labels)

# Use DataLoader to get batches
loader = DataLoader(dataset, batch_size=2, shuffle=True)

for batch_data, batch_labels in loader:
    print('Batch data:', batch_data)
    print('Batch labels:', batch_labels)
    print('---')
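Because shuffle=True is set, the batch order changes from run to run. If a repeatable order is needed, one option is to pass a seeded torch.Generator through DataLoader's generator argument. The sketch below assumes the default num_workers=0, and the epoch_order helper is a hypothetical name introduced only for this illustration.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SimpleDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


features = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
labels = torch.tensor([0, 1, 0, 1])
dataset = SimpleDataset(features, labels)


def epoch_order(seed):
    # Hypothetical helper: one shuffled pass over the data with a fixed seed.
    generator = torch.Generator().manual_seed(seed)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, generator=generator)
    return [batch_data.tolist() for batch_data, _ in loader]


first = epoch_order(0)
second = epoch_order(0)

print(len(first))       # 2 batches, since 4 samples / batch_size 2
print(first == second)  # True: the same seed gives the same shuffled order
```

Shuffling only changes the order. Every sample still appears exactly once per pass over the loader.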
Important Notes
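Return the sample and label from __getitem__ as tensors, or as types that DataLoader's default collate function can convert to tensors (such as NumPy arrays and Python numbers), so that batches can be fed directly to a PyTorch model.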
You can add data transformations inside __getitem__ if needed.
Using DataLoader with your custom dataset helps with batching and shuffling automatically.
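The notes above can be sketched together in one small example. ListDataset and scale_by_ten are hypothetical names used only for illustration: the dataset stores plain Python lists, converts each sample to a tensor and applies a transform inside __getitem__, and leaves batching to DataLoader.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ListDataset(Dataset):
    """Hypothetical dataset that stores Python lists and converts per sample."""

    def __init__(self, rows, labels, transform=None):
        self.rows = rows
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        # Convert the raw values to tensors at access time.
        sample = torch.tensor(self.rows[idx], dtype=torch.float32)
        if self.transform is not None:
            sample = self.transform(sample)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample, label


def scale_by_ten(sample):
    # Illustrative transform: divide every feature by 10.
    return sample / 10.0


rows = [[1, 2], [3, 4], [5, 6], [7, 8]]
labels = [0, 1, 0, 1]
dataset = ListDataset(rows, labels, transform=scale_by_ten)
loader = DataLoader(dataset, batch_size=2, shuffle=False)

# DataLoader stacks the per-sample tensors into one batch tensor.
batch_data, batch_labels = next(iter(loader))
print(batch_data.shape, batch_data.dtype)      # torch.Size([2, 2]) torch.float32
print(batch_labels.shape, batch_labels.dtype)  # torch.Size([2]) torch.int64
```

Keeping the conversion and transform in __getitem__ means the work is done one sample at a time, only when that sample is requested.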
Summary
A Custom Dataset class organizes your data for PyTorch models.
Implement __init__, __len__, and __getitem__ methods.
Use DataLoader with your dataset to handle batches and shuffling easily.