How to Use random_split in PyTorch for Dataset Splitting
Use torch.utils.data.random_split to split a dataset into non-overlapping random subsets by specifying the dataset and the length of each split. It returns a list of subset datasets that you can use for training, validation, or testing.
Syntax
The random_split function takes two main arguments: the dataset to split and a list or tuple of lengths for each split. It returns a list of subsets corresponding to those lengths.
- dataset: The full dataset to split.
- lengths: A list or tuple of integers specifying the size of each split. The sum must equal the dataset length. Recent PyTorch releases also accept fractions that sum to 1.
- generator (optional): A torch.Generator to control randomness for reproducibility.
```python
torch.utils.data.random_split(dataset, lengths, generator=None)
```
Example
This example shows how to split a dataset of 100 samples into training (80 samples) and validation (20 samples) subsets using random_split. It prints the sizes of each subset.
```python
import torch
from torch.utils.data import TensorDataset, random_split

# Create a dummy dataset of 100 samples
data = torch.arange(100).unsqueeze(1).float()  # shape (100, 1)
target = torch.arange(100).float()  # dummy targets
dataset = TensorDataset(data, target)

# Define lengths for train and validation splits
train_len = 80
val_len = 20

# Split dataset randomly
train_set, val_set = random_split(dataset, [train_len, val_len])

# Print sizes
print(f"Train set size: {len(train_set)}")
print(f"Validation set size: {len(val_set)}")
```
Output
Train set size: 80
Validation set size: 20
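The subsets can be passed to a DataLoader in the same way as the original dataset. The following is a minimal sketch that reuses the dummy dataset from the example above; the batch size of 16 and the loader names are arbitrary choices for illustration, not part of the original example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Same dummy dataset as in the example above
data = torch.arange(100).unsqueeze(1).float()
target = torch.arange(100).float()
dataset = TensorDataset(data, target)

train_set, val_set = random_split(dataset, [80, 20])

# batch_size=16 is an arbitrary value chosen for illustration
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16, shuffle=False)

# Each batch is a (features, targets) pair, as with the original dataset
features, targets = next(iter(train_loader))
print(features.shape, targets.shape)
```

Shuffling is usually enabled only for the training loader, since the order of validation samples does not affect the computed metrics.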
Common Pitfalls
- Sum of lengths mismatch: The sum of the lengths must exactly equal the dataset size, or random_split will raise an error.
- Reproducibility: Without setting a generator with a fixed seed, splits will differ each run.
- Dataset type: The dataset must support indexing and have a defined length.
```python
import torch
from torch.utils.data import TensorDataset, random_split

# Dataset with 10 samples
dataset = TensorDataset(torch.arange(10))

# Wrong: lengths sum to 9 instead of 10
try:
    splits = random_split(dataset, [5, 4])
except ValueError as e:
    print(f"Error: {e}")

# Right: lengths sum to 10
splits = random_split(dataset, [6, 4])
print(f"Split sizes: {[len(s) for s in splits]}")
```
Output
Error: Sum of input lengths does not equal the length of the input dataset!
Split sizes: [6, 4]
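The reproducibility pitfall can be sketched in a similar way. The example below assumes a seed of 42 and a hypothetical helper named split_indices, neither of which appears in the original text; it builds a fresh torch.Generator for each call and checks that the same seed produces the same split.

```python
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(10))

def split_indices(seed):
    # Hypothetical helper: a fresh generator with a fixed seed
    # makes the random permutation deterministic
    generator = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(dataset, [6, 4], generator=generator)
    return list(train_set.indices), list(val_set.indices)

first = split_indices(42)
second = split_indices(42)
print(first == second)  # True
```

Passing a dedicated generator keeps the split independent of any other random operations that run earlier in the script.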
Quick Reference
| Parameter | Description |
|---|---|
| dataset | The dataset to split (must support indexing and have length) |
| lengths | List or tuple of integers specifying sizes of splits (sum must equal dataset length) |
| generator | Optional torch.Generator for reproducible splits |
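When split sizes come from percentages, rounding can leave the integer lengths short of the dataset size. One possible approach, sketched here with an assumed 70/15/15 ratio and a 103-sample dummy dataset, is to compute the last length as the remainder so that the sum always matches.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# 103 samples, chosen so the ratios do not divide evenly
dataset = TensorDataset(torch.arange(103))

n = len(dataset)
train_len = int(0.7 * n)
val_len = int(0.15 * n)
# The last split takes whatever remains, so the lengths always sum to n
test_len = n - train_len - val_len

train_set, val_set, test_set = random_split(
    dataset, [train_len, val_len, test_len]
)
print(len(train_set), len(val_set), len(test_set))  # 72 15 16
```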
Key Takeaways
Use torch.utils.data.random_split to split datasets into random subsets by specifying lengths.
Ensure the sum of split lengths equals the dataset size to avoid errors.
Set a torch.Generator with a fixed seed for reproducible splits.
random_split returns a list of subset datasets that can be used like the original dataset.
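To illustrate the last point, the short sketch below inspects the returned objects. It relies on the dataset and indices attributes of torch.utils.data.Subset and uses a 10-sample dummy dataset chosen only for illustration.

```python
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(10))
train_set, val_set = random_split(dataset, [6, 4])

# Each Subset keeps a reference to the parent dataset plus a list of indices
print(train_set.dataset is dataset)  # True

# Indexing a subset looks up the corresponding sample in the parent dataset
first_idx = train_set.indices[0]
print(torch.equal(train_set[0][0], dataset[first_idx][0]))  # True

# The index lists are disjoint and together cover the whole dataset
overlap = set(train_set.indices) & set(val_set.indices)
print(len(overlap))  # 0
```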