How to Create a Custom Dataset in PyTorch Easily
To create a custom dataset in PyTorch, subclass torch.utils.data.Dataset and implement the __len__ and __getitem__ methods. This lets you load and transform your data easily for training models.
Syntax
To create a custom dataset, you need to subclass torch.utils.data.Dataset and define two methods:
- __len__(self): Returns the total number of samples.
- __getitem__(self, idx): Returns the sample at index idx.
This structure allows PyTorch to access your data like a list and use it in data loaders.
```python
import torch
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    def __init__(self, data):
        # Store (or load) the underlying data
        self.data = data

    def __len__(self):
        # Total number of samples
        return len(self.data)

    def __getitem__(self, idx):
        # One sample, looked up by index
        return self.data[idx]
```
Example
This example shows a custom dataset that stores a list of numbers and returns their squares as samples. It demonstrates how to create the dataset, get its length, and access samples.
```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquareDataset(Dataset):
    def __init__(self, numbers):
        self.numbers = numbers

    def __len__(self):
        return len(self.numbers)

    def __getitem__(self, idx):
        x = self.numbers[idx]
        # Return an (input, target) pair
        return x, x ** 2


# Create dataset with numbers 0 to 4
dataset = SquareDataset(list(range(5)))

# Use DataLoader to iterate in batches of 2, in order
loader = DataLoader(dataset, batch_size=2, shuffle=False)

for batch in loader:
    inputs, targets = batch
    print(f"Inputs: {inputs.tolist()}, Targets: {targets.tolist()}")
```
Output
Inputs: [0, 1], Targets: [0, 1]
Inputs: [2, 3], Targets: [4, 9]
Inputs: [4], Targets: [16]
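The dataset can also be inspected directly, without a DataLoader, which is a quick way to check its length and individual samples. The following is a minimal sketch that repeats the SquareDataset class from the example above so it runs on its own; the variable names num_samples and sample are only illustrative.

```python
from torch.utils.data import Dataset


class SquareDataset(Dataset):
    """Same class as in the example above, repeated so this snippet is self-contained."""

    def __init__(self, numbers):
        self.numbers = numbers

    def __len__(self):
        return len(self.numbers)

    def __getitem__(self, idx):
        x = self.numbers[idx]
        return x, x ** 2


dataset = SquareDataset(list(range(5)))

# len() calls __len__
num_samples = len(dataset)

# Indexing calls __getitem__ and returns one (input, target) pair
sample = dataset[3]

print(num_samples)  # 5
print(sample)       # (3, 9)
```

Note that direct indexing returns plain Python integers here; it is the DataLoader that collates samples into tensors when batching.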
Common Pitfalls
Common mistakes when creating custom datasets include:
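- Not implementing __len__ or __getitem__: calling len() on the dataset then raises a TypeError, and indexing it raises NotImplementedError.
- Returning samples in a format the DataLoader cannot batch (return tensors, or values convertible to tensors such as numbers or NumPy arrays).
- Not handling index errors if idx is out of range.
- Forgetting to apply necessary data transformations inside __getitem__.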
Always test your dataset by accessing some samples before training.
```python
import torch
from torch.utils.data import Dataset


# Wrong: Missing __len__ method
class BadDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        return self.data[idx]


# Right: Implement both methods
class GoodDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
```
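To make the difference between the two classes concrete, the sketch below runs a small sanity check against both. The helper check_dataset is hypothetical (it is not part of PyTorch); it simply calls len() and reads the first few samples, which is one possible way to test a dataset before training.

```python
from torch.utils.data import Dataset


class BadDataset(Dataset):
    """Missing __len__, as in the 'wrong' example above."""

    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        return self.data[idx]


class GoodDataset(Dataset):
    """Implements both required methods."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def check_dataset(dataset, num_to_check=3):
    """Hypothetical helper: read the length and the first few samples."""
    n = len(dataset)  # raises TypeError if __len__ is not defined
    return [dataset[i] for i in range(min(n, num_to_check))]


# The dataset without __len__ fails as soon as its length is requested
try:
    check_dataset(BadDataset([10, 20, 30]))
    bad_error = None
except TypeError as exc:
    bad_error = exc

# The complete dataset passes the check and returns its first samples
good_samples = check_dataset(GoodDataset([10, 20, 30]), num_to_check=2)

print(type(bad_error).__name__)
print(good_samples)
```

Because GoodDataset indexes a Python list, an out-of-range idx raises the list's own IndexError; a dataset backed by another storage type may need an explicit bounds check.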
Quick Reference
Custom Dataset Creation Steps:
- Subclass torch.utils.data.Dataset
- Define __init__ to load or receive data
- Define __len__ to return dataset size
- Define __getitem__ to return one sample by index
- Use DataLoader to batch and shuffle data
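The steps above can be sketched end to end as follows. This is an illustrative example under stated assumptions, not a prescribed pattern: the class name ScaledDataset, its optional transform argument, and the scaling lambda are all made up for the demonstration. It shows one way to apply a transformation inside __getitem__ and return tensors so the DataLoader can batch them.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ScaledDataset(Dataset):  # hypothetical name, for illustration only
    def __init__(self, values, transform=None):
        # Step 2: receive the data and an optional per-sample transform
        self.values = values
        self.transform = transform

    def __len__(self):
        # Step 3: dataset size
        return len(self.values)

    def __getitem__(self, idx):
        # Step 4: build one sample as a tensor, then transform it
        x = torch.tensor(self.values[idx], dtype=torch.float32)
        if self.transform is not None:
            x = self.transform(x)
        return x


# Assumed transform: scale values into the range [0, 1]
dataset = ScaledDataset([0.0, 1.0, 2.0, 3.0], transform=lambda x: x / 3.0)

# Step 5: batch and shuffle with DataLoader
loader = DataLoader(dataset, batch_size=2, shuffle=True)

batches = [batch for batch in loader]
for batch in batches:
    print(batch.shape)  # torch.Size([2])
```

Since shuffle=True, the order of values across batches changes from run to run, but each batch still holds two transformed samples.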
Key Takeaways
Subclass torch.utils.data.Dataset and implement __len__ and __getitem__ to create a custom dataset.
Use __getitem__ to load and optionally transform each data sample.
Test your dataset by accessing samples before using it in training.
Use DataLoader to batch and shuffle your custom dataset easily.
Avoid missing methods or wrong data formats to prevent runtime errors.