PyTorch · How-To · Beginner · 4 min read

How to Use Dataset in PyTorch: Simple Guide with Examples

In PyTorch, Dataset is an abstract class you subclass to load and access your data samples and labels. You create a custom dataset by overriding __len__ to return the dataset size and __getitem__ to fetch a sample by index. This lets you plug your data into PyTorch's DataLoader for batching and shuffling.
📐

Syntax

To use Dataset in PyTorch, you create a class that inherits from torch.utils.data.Dataset. You must define two methods:

  • __len__(self): returns the total number of samples.
  • __getitem__(self, idx): returns the sample and label at index idx.

This structure allows PyTorch to access your data efficiently.

python
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
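Once the class is defined, using it is just indexing. As a quick sketch (the toy tensors below are made up for illustration), you can instantiate CustomDataset and access samples directly:

```python
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Hypothetical toy data: 4 feature vectors with 2 features each
data = torch.randn(4, 2)
labels = torch.tensor([0, 1, 0, 1])

ds = CustomDataset(data, labels)

print(len(ds))     # 4, via __len__
x, y = ds[2]       # third sample, via __getitem__
print(x.shape, y)  # torch.Size([2]) tensor(0)
```

Indexing with `ds[2]` calls `__getitem__(2)` under the hood, which is exactly the hook DataLoader uses to pull samples.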
💻

Example

This example shows how to create a simple dataset of numbers and their squares, then use a DataLoader to iterate in batches.

python
import torch
from torch.utils.data import Dataset, DataLoader

class SquareDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        x = idx
        y = idx * idx
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

# Create dataset of size 10
dataset = SquareDataset(10)

# Create DataLoader for batch size 3
loader = DataLoader(dataset, batch_size=3, shuffle=True)

for batch_idx, (inputs, targets) in enumerate(loader):
    print(f"Batch {batch_idx + 1}")
    print("Inputs:", inputs)
    print("Targets:", targets)
    print()
Output
Batch 1
Inputs: tensor([3., 1., 0.])
Targets: tensor([ 9., 1., 0.])

Batch 2
Inputs: tensor([7., 8., 6.])
Targets: tensor([49., 64., 36.])

Batch 3
Inputs: tensor([9., 4., 5.])
Targets: tensor([81., 16., 25.])

Batch 4
Inputs: tensor([2.])
Targets: tensor([4.])
⚠️

Common Pitfalls

Common mistakes when using Dataset include:

  • Not implementing __len__ or __getitem__, which causes errors.
  • Returning data in the wrong format (e.g., Python lists instead of tensors), which breaks training.
  • Not handling indexing properly, especially when data is stored in complex structures.
  • Forgetting to enable shuffle in DataLoader when shuffling is needed.
python
import torch
from torch.utils.data import Dataset

# Wrong: missing __len__ method
class BadDataset(Dataset):
    def __getitem__(self, idx):
        return idx

# Right: implement both methods
class GoodDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return idx
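The second pitfall, returning non-tensor data, is easy to fix by converting inside __getitem__. A minimal sketch (the list-of-lists data here is invented for illustration):

```python
import torch
from torch.utils.data import Dataset

# Pitfall fix: raw data stored as plain Python lists is
# converted to tensors inside __getitem__.
class ListDataset(Dataset):
    def __init__(self, rows, labels):
        self.rows = rows      # list of feature lists
        self.labels = labels  # list of int class labels

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        # Convert on access so the model always receives tensors
        x = torch.tensor(self.rows[idx], dtype=torch.float32)
        y = torch.tensor(self.labels[idx], dtype=torch.long)
        return x, y

ds = ListDataset([[1.0, 2.0], [3.0, 4.0]], [0, 1])
x, y = ds[1]
print(x.dtype, y.item())  # torch.float32 1
```

Converting in __getitem__ keeps memory usage low for large datasets, since tensors are only created for the samples actually requested.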
📊

Quick Reference

  • __len__: Return dataset size.
  • __getitem__: Return one data sample by index.
  • DataLoader: Wrap dataset for batching and shuffling.
  • Always return tensors or convert data to tensors in __getitem__.
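When your data already lives in tensors, you may not need a custom class at all: PyTorch's built-in torch.utils.data.TensorDataset handles the __len__/__getitem__ boilerplate for you. A short sketch (the tensors here are illustrative):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Features: 6 samples with 1 feature each; targets: each feature doubled
features = torch.arange(6, dtype=torch.float32).reshape(6, 1)
targets = features.squeeze(1) * 2

# TensorDataset indexes all tensors in parallel along dim 0
ds = TensorDataset(features, targets)
loader = DataLoader(ds, batch_size=2)

for xb, yb in loader:
    print(xb.shape, yb.shape)  # torch.Size([2, 1]) torch.Size([2])
```

A custom Dataset is still the right tool when samples need loading from disk or per-item transformation; TensorDataset is the shortcut for in-memory tensor data.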

Key Takeaways

  • Create a custom Dataset by subclassing torch.utils.data.Dataset and implementing __len__ and __getitem__.
  • Use DataLoader to batch, shuffle, and load data efficiently from your Dataset.
  • Always return data samples as tensors in __getitem__ for compatibility with PyTorch models.
  • Implement __len__ correctly to avoid errors during training or evaluation.
  • Shuffling data in DataLoader improves training by mixing the sample order each epoch.