PyTorch · How-To · Beginner · 3 min read

How to Split Dataset in PyTorch: Simple Guide with Examples

In PyTorch, you can split a dataset using torch.utils.data.random_split to create random train and test subsets. Alternatively, use SubsetRandomSampler with DataLoader for custom splits.
📐

Syntax

torch.utils.data.random_split splits a dataset into non-overlapping new datasets of given lengths.

  • dataset: The original dataset to split.
  • lengths: List or tuple of lengths for each split.
  • Returns a list of subsets.

SubsetRandomSampler allows sampling elements randomly from a list of indices, useful for DataLoader.

python
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

sampler = torch.utils.data.SubsetRandomSampler(indices)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
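Putting the SubsetRandomSampler pieces together, here is a minimal runnable sketch of an index-based 80/20 split. It uses a toy TensorDataset as a stand-in for any dataset; the shuffle-then-slice pattern on the index list is what makes the two samplers non-overlapping.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Toy dataset of 100 samples (stand-in for any Dataset)
dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))

# Shuffle all indices once, then carve out an 80/20 split
indices = torch.randperm(len(dataset)).tolist()
split = int(0.8 * len(dataset))
train_sampler = SubsetRandomSampler(indices[:split])
test_sampler = SubsetRandomSampler(indices[split:])

# Note: sampler is mutually exclusive with shuffle=True in DataLoader
train_loader = DataLoader(dataset, sampler=train_sampler, batch_size=10)
test_loader = DataLoader(dataset, sampler=test_sampler, batch_size=10)

print(sum(len(b[0]) for b in train_loader))  # 80
print(sum(len(b[0]) for b in test_loader))   # 20
```

Unlike random_split, this approach keeps the original dataset object intact, which is handy when the train and test pipelines need different transforms.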
💻

Example

This example shows how to split a dataset into 80% training and 20% testing using random_split. It prints the sizes of each split.

python
import torch
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

# Create a fake dataset of 100 samples
full_dataset = FakeData(size=100, transform=ToTensor())

# Define split sizes
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

# Split dataset
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")
Output
Train dataset size: 80
Test dataset size: 20
⚠️

Common Pitfalls

  • Not setting the correct lengths for splits can cause errors or unexpected sizes.
  • Splitting by taking consecutive index ranges without shuffling them first can produce biased subsets if the dataset is ordered (for example, sorted by class).
  • For reproducibility, set a random seed before splitting.
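The last point above can be sketched by passing a seeded torch.Generator to random_split, which makes the split deterministic without touching the global RNG state:

```python
import torch
from torch.utils.data import random_split, TensorDataset

dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))

# Seeded generator: the same seed always produces the same split
gen = torch.Generator().manual_seed(42)
train_a, _ = random_split(dataset, [80, 20], generator=gen)

gen = torch.Generator().manual_seed(42)
train_b, _ = random_split(dataset, [80, 20], generator=gen)

print(train_a.indices == train_b.indices)  # True: identical splits across runs
```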

Example of a common mistake and fix:

python
# Wrong: lengths do not sum to dataset size
# train_dataset, test_dataset = random_split(full_dataset, [70, 40])  # 70 + 40 != 100

# Right: lengths sum to dataset size
train_dataset, test_dataset = random_split(full_dataset, [70, 30])
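If you are on PyTorch 1.13 or newer, random_split also accepts fractions that sum to 1.0, which sidesteps length-mismatch mistakes like the one above entirely; a brief sketch:

```python
import torch
from torch.utils.data import random_split, TensorDataset

dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))

# Fractions summing to 1.0 (supported in PyTorch 1.13+);
# lengths are computed from the dataset size automatically
train_ds, test_ds = random_split(dataset, [0.8, 0.2])
print(len(train_ds), len(test_ds))  # 80 20
```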
📊

Quick Reference

Method · Description · Usage

  • random_split — Splits a dataset into random subsets by lengths. Usage: train_ds, test_ds = random_split(dataset, [train_len, test_len])
  • SubsetRandomSampler — Samples elements randomly from given indices. Usage: sampler = SubsetRandomSampler(indices); dataloader = DataLoader(dataset, sampler=sampler)
  • DataLoader shuffle — Shuffles data each epoch (only for the full dataset). Usage: DataLoader(dataset, shuffle=True)

Key Takeaways

  • Use torch.utils.data.random_split to easily split datasets by specifying lengths.
  • Ensure split lengths sum exactly to the dataset size to avoid errors.
  • Set a random seed for reproducible splits when using random_split.
  • SubsetRandomSampler allows custom index-based sampling for flexible splits.
  • Use DataLoader with shuffle=True only when you want to shuffle the entire dataset.