PyTorch · ~5 mins

Train/val/test split in PyTorch

Introduction
Splitting data into train, validation, and test sets helps us build models that generalize well and lets us measure performance on data the model has never seen. This split is useful:
When you want to train a model and track its learning progress.
When you want to tune hyperparameters without cheating by peeking at the test data.
When you want to measure how well your model performs on data it has never seen.
When you want to detect and avoid overfitting.
When you want to compare different models fairly.
Syntax
PyTorch
from sklearn.model_selection import train_test_split

# Split data into train and temp (val+test)
train_data, temp_data = train_test_split(data, test_size=0.4, random_state=42)

# Split temp_data into validation and test
val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)
We first split the data into a training set and a temporary set (60%/40% with test_size=0.4); the temporary set will be split again.
Then we split the temporary set in half, yielding 20% validation and 20% test.
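As a minimal sketch of the two-step split above, using a dummy list of 100 numbers in place of a real dataset, we can confirm the resulting 60/20/20 division:

```python
from sklearn.model_selection import train_test_split

# Dummy dataset: 100 samples (stand-in for real data)
data = list(range(100))

# Step 1: 60% train, 40% temporary (validation + test)
train_data, temp_data = train_test_split(data, test_size=0.4, random_state=42)

# Step 2: split the temporary set in half -> 20% validation, 20% test
val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)

print(len(train_data), len(val_data), len(test_data))  # 60 20 20
```

The same random_state on both calls keeps the split reproducible across runs.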
Examples
Splits data into 70% train, 15% val, and 15% test.
PyTorch
train_data, temp_data = train_test_split(data, test_size=0.3)
val_data, test_data = train_test_split(temp_data, test_size=0.5)
Splits data into 80% train and 20% validation only.
PyTorch
train_data, val_data = train_test_split(data, test_size=0.2)
# No separate test set here
Sample Model
This code creates a simple dataset and splits it into train, validation, and test sets using PyTorch's random_split. It then prints the number of samples in each set.
PyTorch
import torch
from torch.utils.data import DataLoader, random_split, TensorDataset

# Create dummy dataset of 100 samples with 5 features
X = torch.randn(100, 5)
y = torch.randint(0, 2, (100,))

# Combine features and labels into a TensorDataset
dataset = TensorDataset(X, y)

# Define split sizes: 70 train, 15 val, 15 test
train_size = 70
val_size = 15
test_size = 15

# Use random_split to split dataset
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=10)
val_loader = DataLoader(val_dataset, batch_size=10)
test_loader = DataLoader(test_dataset, batch_size=10)

# Print sizes to confirm
print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")
Output
Train samples: 70
Validation samples: 15
Test samples: 15
Important Notes
Always set a random seed (e.g. random_state in sklearn, or a seeded torch.Generator in PyTorch) so you get the same split every time you run your code.
Validation data helps tune your model without touching test data.
Test data should only be used once at the end to check final model performance.
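To see why seeding matters, here is a plain-Python sketch (no libraries; split_indices is a hypothetical helper, not part of PyTorch or sklearn) that shuffles indices with a fixed seed and cuts them into three parts:

```python
import random

def split_indices(n, train_frac=0.7, val_frac=0.15, seed=42):
    """Shuffle indices with a fixed seed, then cut into train/val/test."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # same seed -> same shuffle every run
    n_train = int(n * train_frac)
    n_val = int(n * val_frac)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]

train_idx, val_idx, test_idx = split_indices(100)
print(len(train_idx), len(val_idx), len(test_idx))  # 70 15 15

# Re-running with the same seed reproduces the identical split
assert split_indices(100) == split_indices(100)
```

Because the shuffle is seeded, anyone re-running the code gets exactly the same train, validation, and test members, which is what random_state and manual_seed provide in the library-based examples above.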
Summary
Split your data into three parts: train, validation, and test.
Train data teaches the model, validation data tunes it, and test data checks it.
Use tools like sklearn's train_test_split or PyTorch's random_split for easy splitting.