
How to Use random_split in PyTorch for Dataset Splitting

Use torch.utils.data.random_split to split a dataset into non-overlapping random subsets by specifying the dataset and lengths of each split. It returns a list of subset datasets that you can use for training, validation, or testing.

Syntax

The random_split function takes two required arguments, the dataset to split and a list or tuple of lengths for each split, plus an optional generator. It returns a list of subsets, one per entry in lengths and in the same order.

  • dataset: The full dataset to split.
  • lengths: A list or tuple of integers specifying the sizes of each split. The sum must equal the dataset length. Recent PyTorch releases also accept fractions that sum to 1.
  • generator (optional): A torch.Generator to control randomness for reproducibility.
python
torch.utils.data.random_split(dataset, lengths, generator=None)
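As a minimal sketch of how the three parameters fit together, the call below splits a small toy dataset into three parts and passes a seeded generator by keyword. The dataset, the seed value, and the variable names are assumptions made for illustration only.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical toy dataset of 10 samples, used only for illustration
dataset = TensorDataset(torch.arange(10))

# Seeded generator (the seed 0 is an arbitrary choice)
gen = torch.Generator().manual_seed(0)

# One length per split; 6 + 2 + 2 equals len(dataset)
train_set, val_set, test_set = random_split(dataset, [6, 2, 2], generator=gen)

print(len(train_set), len(val_set), len(test_set))  # 6 2 2
```

The number of returned subsets always matches the number of entries in lengths, so the result can be unpacked directly.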

Example

This example shows how to split a dataset of 100 samples into training (80 samples) and validation (20 samples) subsets using random_split. It prints the sizes of each subset.

python
import torch
from torch.utils.data import TensorDataset, random_split

# Create a dummy dataset of 100 samples
data = torch.arange(100).unsqueeze(1).float()  # shape (100, 1)
target = torch.arange(100).float()  # dummy targets

dataset = TensorDataset(data, target)

# Define lengths for train and validation splits
train_len = 80
val_len = 20

# Split dataset randomly
train_set, val_set = random_split(dataset, [train_len, val_len])

# Print sizes
print(f"Train set size: {len(train_set)}")
print(f"Validation set size: {len(val_set)}")
Output
Train set size: 80
Validation set size: 20
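To see what the returned subsets actually hold, the sketch below repeats the same 80/20 split and inspects the results. It assumes a PyTorch version in which each split is a torch.utils.data.Subset exposing an indices attribute. The variable names train_idx and val_idx are illustrative.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Same dummy dataset as in the example: 100 samples
data = torch.arange(100).unsqueeze(1).float()
target = torch.arange(100).float()
dataset = TensorDataset(data, target)

train_set, val_set = random_split(dataset, [80, 20])

# Each split stores positions in the original dataset
train_idx = set(int(i) for i in train_set.indices)
val_idx = set(int(i) for i in val_set.indices)

print(type(train_set).__name__)       # Subset
print(train_idx.isdisjoint(val_idx))  # True: the splits do not overlap
print(len(train_idx | val_idx))       # 100: together they cover the dataset

# Indexing a subset returns a sample from the underlying dataset
x, y = val_set[0]
```

Because a subset only stores indices, no data is copied. The subsets are views onto the original dataset.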

Common Pitfalls

  • Sum of lengths mismatch: The sum of the integer lengths must exactly equal the dataset size, or random_split raises a ValueError.
  • Reproducibility: Unless you pass a generator with a fixed seed (or seed the global RNG with torch.manual_seed), the splits will differ on each run.
  • Dataset type: The dataset must support indexing and have a defined length.
python
import torch
from torch.utils.data import TensorDataset, random_split

# Dataset with 10 samples
dataset = TensorDataset(torch.arange(10))

# Wrong: lengths sum to 9 instead of 10
try:
    splits = random_split(dataset, [5, 4])
except ValueError as e:
    print(f"Error: {e}")

# Right: lengths sum to 10
splits = random_split(dataset, [6, 4])
print(f"Split sizes: {[len(s) for s in splits]}")
Output
Error: Sum of input lengths does not equal the length of the input dataset!
Split sizes: [6, 4]
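The reproducibility pitfall can be checked in the same way. In this sketch, split_with_seed is a hypothetical helper (not part of PyTorch) that builds a fresh torch.Generator with a given seed for each call. The seed value 42 is arbitrary.

```python
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(10))

def split_with_seed(seed):
    # Hypothetical helper: a fresh generator per call, seeded identically
    gen = torch.Generator().manual_seed(seed)
    return random_split(dataset, [6, 4], generator=gen)

first_train, first_val = split_with_seed(42)
second_train, second_val = split_with_seed(42)

# The same seed gives the same indices in both runs
print(list(first_train.indices) == list(second_train.indices))  # True
print(list(first_val.indices) == list(second_val.indices))      # True
```

Creating a new generator for each split keeps the result independent of any other random operations in the program.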

Quick Reference

Parameter   Description
dataset     The dataset to split (must support indexing and have length)
lengths     List or tuple of integers specifying sizes of splits (sum must equal dataset length)
generator   Optional torch.Generator for reproducible splits

Key Takeaways

  • Use torch.utils.data.random_split to split datasets into random subsets by specifying lengths.
  • Ensure the sum of split lengths equals the dataset size to avoid errors.
  • Set a torch.Generator with a fixed seed for reproducible splits.
  • random_split returns a list of subset datasets that can be used like the original dataset.
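One way to keep the lengths consistent with the dataset size is to compute one length from a ratio and derive the other as the remainder. The sketch below assumes a 20% validation share and a dataset size of 103. Both values are chosen only to show a case that does not divide evenly.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical dataset whose size is not a multiple of 10
dataset = TensorDataset(torch.arange(103))

val_fraction = 0.2  # assumed validation share
val_len = int(len(dataset) * val_fraction)  # rounds down to 20
train_len = len(dataset) - val_len          # remainder goes to training

# The two lengths sum to len(dataset) by construction
train_set, val_set = random_split(dataset, [train_len, val_len])

print(len(train_set), len(val_set))  # 83 20
```

Deriving the last length by subtraction avoids the rounding mismatch that can occur when every length is computed from a ratio independently.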