How to Split a Dataset in PyTorch: Simple Guide with Examples
In PyTorch, you can split a dataset using torch.utils.data.random_split to create random train and test subsets. Alternatively, use SubsetRandomSampler with DataLoader for custom index-based splits.
Syntax
torch.utils.data.random_split splits a dataset into non-overlapping new datasets of given lengths.
- dataset: The original dataset to split.
- lengths: List or tuple of lengths for each split.
- Returns a list of Subset objects, one per length.
SubsetRandomSampler allows sampling elements randomly from a list of indices, useful for DataLoader.
```python
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

sampler = torch.utils.data.SubsetRandomSampler(indices)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
```
Example
This example shows how to split a dataset into 80% training and 20% testing using random_split. It prints the sizes of each split.
```python
import torch
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

# Create a fake dataset of 100 samples
full_dataset = FakeData(size=100, transform=ToTensor())

# Define split sizes
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

# Split dataset
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")
```
Output
Train dataset size: 80
Test dataset size: 20
Common Pitfalls
- Passing lengths that do not sum to the dataset size raises an error or produces unexpected split sizes.
- Splitting by taking consecutive index ranges without shuffling first can produce biased splits, e.g., when the dataset is ordered by class.
- For reproducibility, set a random seed (or pass a seeded generator) before splitting.
Example of a common mistake and fix:
```python
# Wrong: lengths do not sum to the dataset size (100)
# train_dataset, test_dataset = random_split(full_dataset, [70, 40])  # 70 + 40 != 100

# Right: lengths sum to the dataset size
train_dataset, test_dataset = random_split(full_dataset, [70, 30])
Quick Reference
| Method | Description | Usage |
|---|---|---|
| random_split | Splits dataset into random subsets by lengths | train_ds, test_ds = random_split(dataset, [train_len, test_len]) |
| SubsetRandomSampler | Samples elements randomly from given indices | sampler = SubsetRandomSampler(indices); dataloader = DataLoader(dataset, sampler=sampler) |
| DataLoader shuffle | Reshuffles samples each epoch; cannot be combined with a custom sampler | DataLoader(dataset, shuffle=True) |
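As a side note, in recent PyTorch releases (1.13 and later) random_split also accepts fractions that sum to 1, so you can skip the length arithmetic; PyTorch resolves the fractions to integer lengths internally. A minimal sketch using a plain list as a stand-in dataset:

```python
from torch.utils.data import random_split

dataset = list(range(100))

# Fractions are converted to lengths internally (80 and 20 here);
# requires PyTorch 1.13 or newer
train_ds, test_ds = random_split(dataset, [0.8, 0.2])
print(len(train_ds), len(test_ds))  # 80 20
```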
Key Takeaways
- Use torch.utils.data.random_split to split datasets easily by specifying lengths.
- Ensure split lengths sum exactly to the dataset size to avoid errors.
- Pass a seeded generator to random_split for reproducible splits.
- SubsetRandomSampler enables custom index-based sampling for flexible splits.
- Use shuffle=True in DataLoader to reshuffle samples each epoch; never combine it with a custom sampler.