We use custom Dataset classes to organize and load our own data easily for training machine learning models.
Dataset class (custom datasets) in PyTorch
Introduction
A custom Dataset class is useful in situations like these:
When you have images stored in folders and want to load them with labels.
When your data is in a CSV file and you want to feed it to a model.
When you want to apply transformations like resizing or normalization on the fly.
When you want to load data in batches efficiently during training.
When you want to combine multiple data sources into one dataset.
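The CSV case above can be sketched as follows. This is an illustrative example, not part of PyTorch: CSVDataset and the toy file it writes are made up for the demo, and the whole file is read into memory in __init__, which is fine for small files.

```python
import csv
import os
import tempfile

import torch
from torch.utils.data import Dataset

# Write a tiny CSV so the example is self-contained; in practice
# you would point the dataset at an existing file.
csv_path = os.path.join(tempfile.gettempdir(), "toy_data.csv")
with open(csv_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["x1", "x2", "label"])
    writer.writerows([[1.0, 2.0, 0], [3.0, 4.0, 1], [5.0, 6.0, 0]])

class CSVDataset(Dataset):
    """Reads all rows once in __init__ and serves them as tensors."""
    def __init__(self, path):
        with open(path) as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            rows = [[float(v) for v in row] for row in reader]
        self.features = torch.tensor([r[:-1] for r in rows])
        self.labels = torch.tensor([int(r[-1]) for r in rows])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

ds = CSVDataset(csv_path)
print(len(ds))   # 3
print(ds[0])
```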
Syntax
PyTorch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, label
The __init__ method sets up your data and any transforms.
The __len__ method returns how many samples you have.
The __getitem__ method returns one sample and label by index.
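Putting the three methods together, here is a minimal runnable sketch. The scaling lambda is just an illustrative transform: any callable works, since the dataset only calls transform(sample).

```python
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, label

data = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
labels = torch.tensor([0, 1])
# Any callable can serve as a transform; here we simply scale the sample.
ds = CustomDataset(data, labels, transform=lambda x: x / 10)

print(len(ds))        # 2  (reported by __len__)
sample, label = ds[0]  # __getitem__ applies the transform
print(sample, label)
```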
Examples
A simple dataset with only data, no labels.
PyTorch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
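For the last use case from the introduction, combining multiple data sources, PyTorch ships a built-in helper, torch.utils.data.ConcatDataset. A short sketch reusing MyDataset:

```python
import torch
from torch.utils.data import Dataset, ConcatDataset

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

a = MyDataset(torch.tensor([1, 2, 3]))
b = MyDataset(torch.tensor([4, 5]))

# Indices 0-2 fall through to a, indices 3-4 to b.
combined = ConcatDataset([a, b])
print(len(combined))   # 5
print(combined[3])     # tensor(4), the first element of b
```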
Dataset for images with optional transformations.
PyTorch
from torch.utils.data import Dataset
from torchvision.io import read_image

class ImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = read_image(self.image_paths[idx])
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
Sample Model
This program creates a simple dataset with features and labels, then loads it in batches of 2 using DataLoader. It prints each batch's data and labels.
PyTorch
import torch
from torch.utils.data import Dataset, DataLoader

class SimpleDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Sample data
features = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
labels = torch.tensor([0, 1, 0, 1])

# Create dataset
dataset = SimpleDataset(features, labels)

# Create DataLoader to load data in batches
loader = DataLoader(dataset, batch_size=2, shuffle=False)

for batch_idx, (data, target) in enumerate(loader):
    print(f"Batch {batch_idx+1} data:", data)
    print(f"Batch {batch_idx+1} labels:", target)
Output
Batch 1 data: tensor([[1, 2],
        [3, 4]])
Batch 1 labels: tensor([0, 1])
Batch 2 data: tensor([[5, 6],
        [7, 8]])
Batch 2 labels: tensor([0, 1])
Important Notes
When your dataset has labels, return the sample and its label as a tuple from __getitem__.
Use transforms to preprocess data inside the dataset, so every sample is prepared consistently.
DataLoader loads data in batches and can shuffle it between epochs if needed.
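The shuffling note can be sketched with TensorDataset, PyTorch's built-in Dataset for in-memory tensors. Passing a seeded torch.Generator is optional; it only makes the shuffle order reproducible for the demo.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = torch.arange(10).unsqueeze(1).float()   # 10 samples, 1 feature each
labels = torch.arange(10)
ds = TensorDataset(data, labels)

# shuffle=True reorders samples each epoch; a seeded generator
# just makes that order reproducible here.
g = torch.Generator().manual_seed(0)
loader = DataLoader(ds, batch_size=4, shuffle=True, generator=g)

for _, label_batch in loader:
    print(label_batch)  # labels appear in shuffled order, 4 per batch
```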
Summary
Custom Dataset classes help organize your data for PyTorch models.
Implement __init__, __len__, and __getitem__ methods.
Use DataLoader to load data in batches during training.