How to Use Dataset in PyTorch: Simple Guide with Examples
In PyTorch, Dataset is a class you extend to load and access your data samples and labels. You create a custom dataset by overriding __len__ to return the dataset size and __getitem__ to fetch a sample by index. This lets your data plug straight into PyTorch's DataLoader for batching and shuffling.
Syntax
To use Dataset in PyTorch, you create a class that inherits from torch.utils.data.Dataset. You must define two methods:
- __len__(self): returns the total number of samples.
- __getitem__(self, idx): returns the sample and label at index idx.
This structure allows PyTorch to access your data efficiently.
```python
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
```
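The skeleton above can be exercised directly. In this sketch, the data and labels are small illustrative tensors chosen just to show how len() and indexing behave:

```python
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Illustrative data: three samples with two features each
data = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
labels = torch.tensor([0, 1, 0])

ds = CustomDataset(data, labels)
n = len(ds)        # calls __len__ -> 3
x1, y1 = ds[1]     # calls __getitem__(1)
print(n, x1, y1)   # 3 tensor([3., 4.]) tensor(1)
```

Note that indexing with ds[1] is what DataLoader does under the hood when it assembles a batch.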
Example
This example shows how to create a simple dataset of numbers and their squares, then use a DataLoader to iterate in batches.
```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquareDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        x = idx
        y = idx * idx
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

# Create a dataset of size 10
dataset = SquareDataset(10)

# Create a DataLoader with batch size 3
loader = DataLoader(dataset, batch_size=3, shuffle=True)

for batch_idx, (inputs, targets) in enumerate(loader):
    print(f"Batch {batch_idx + 1}")
    print("Inputs:", inputs)
    print("Targets:", targets)
    print()
```
Output
Batch 1
Inputs: tensor([3., 1., 0.])
Targets: tensor([ 9., 1., 0.])
Batch 2
Inputs: tensor([7., 8., 6.])
Targets: tensor([49., 64., 36.])
Batch 3
Inputs: tensor([9., 4., 5.])
Targets: tensor([81., 16., 25.])
Batch 4
Inputs: tensor([2.])
Targets: tensor([4.])
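Notice that the last batch holds only one sample, because 10 samples do not divide evenly into batches of 3. If partial batches are unwanted (for example, with batch normalization), DataLoader's drop_last option discards the incomplete final batch; a minimal sketch:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquareDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return (torch.tensor(idx, dtype=torch.float32),
                torch.tensor(idx * idx, dtype=torch.float32))

dataset = SquareDataset(10)

# drop_last=True discards the final incomplete batch (the lone 10th sample)
loader = DataLoader(dataset, batch_size=3, shuffle=True, drop_last=True)

num_batches = len(list(loader))
print(num_batches)  # 3 full batches instead of 4
```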
Common Pitfalls
Common mistakes when using Dataset include:
- Not implementing __len__ or __getitem__, which raises errors at runtime.
- Returning data in the wrong format (e.g., not tensors), which breaks training.
- Not handling indexing properly, especially if data is stored in complex structures.
- Forgetting to enable shuffling in DataLoader when needed.
```python
import torch
from torch.utils.data import Dataset

# Wrong: missing __len__ method
class BadDataset(Dataset):
    def __getitem__(self, idx):
        return idx

# Right: implement both methods
class GoodDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return idx
```
Quick Reference
- __len__: Return dataset size.
- __getitem__: Return one data sample by index.
- DataLoader: Wrap dataset for batching and shuffling.
- Always return tensors or convert data to tensors in __getitem__.
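As a sketch of that last point, a dataset whose raw data lives in plain Python lists can convert each sample to a tensor inside __getitem__. The class and data below are illustrative, not part of PyTorch:

```python
import torch
from torch.utils.data import Dataset

class ListDataset(Dataset):
    """Illustrative dataset: raw samples stored as plain Python lists."""
    def __init__(self, rows, labels):
        self.rows = rows
        self.labels = labels

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        # Convert to tensors here so DataLoader can collate batches
        x = torch.tensor(self.rows[idx], dtype=torch.float32)
        y = torch.tensor(self.labels[idx], dtype=torch.long)
        return x, y

ds = ListDataset([[1.0, 2.0], [3.0, 4.0]], [0, 1])
x0, y0 = ds[0]
print(x0.dtype, y0.item())  # torch.float32 0
```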
Key Takeaways
- Create a custom Dataset by subclassing torch.utils.data.Dataset and implementing __len__ and __getitem__.
- Use DataLoader to batch, shuffle, and load data efficiently from your Dataset.
- Always return data samples as tensors in __getitem__ for compatibility with PyTorch models.
- Implement __len__ correctly to avoid errors during training or evaluation.
- Shuffling data in DataLoader helps improve model training by mixing samples.
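Finally, when your samples are already tensors, PyTorch provides a ready-made Dataset so you can skip the subclass entirely; a brief sketch using torch.utils.data.TensorDataset:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# TensorDataset wraps tensors that share a first dimension;
# indexing returns one (input, target) tuple per sample
inputs = torch.arange(10, dtype=torch.float32)
targets = inputs ** 2

dataset = TensorDataset(inputs, targets)
loader = DataLoader(dataset, batch_size=3, shuffle=True)

x, y = dataset[4]
print(x.item(), y.item())  # 4.0 16.0
```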