How to Shuffle Data in PyTorch DataLoader
To shuffle data in PyTorch's DataLoader, set the parameter shuffle=True when creating the DataLoader. This ensures the data is randomly reordered each epoch, which helps improve model training by reducing bias from data order.
Syntax
The DataLoader class in PyTorch has a shuffle parameter that controls whether the data is shuffled each time it is loaded. Setting shuffle=True randomizes the order of data samples.
- dataset: The dataset to load data from.
- batch_size: Number of samples per batch.
- shuffle: Set to True to shuffle data every epoch, False to keep order.
python
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, ...)
Example
This example shows how to create a DataLoader with shuffled data from a simple dataset of numbers. Each epoch prints batches in a different random order, so the exact output varies from run to run.
python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Create a simple dataset of numbers 0 to 9
data = torch.arange(10)
dataset = TensorDataset(data)

# DataLoader with shuffle=True
loader = DataLoader(dataset, batch_size=3, shuffle=True)

print('Shuffled batches:')
for epoch in range(2):
    print(f'Epoch {epoch + 1}:')
    for batch in loader:
        print(batch[0].tolist())
Output
Shuffled batches:
Epoch 1:
[7, 0, 3]
[1, 9, 6]
[8, 5, 4]
[2]
Epoch 2:
[4, 0, 3]
[9, 7, 6]
[1, 5, 8]
[2]
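Because the order is random, it can be hard to compare runs. The sketch below is one way to make the shuffle reproducible, assuming you pass a seeded torch.Generator through the DataLoader's generator argument. The helper names epoch_order and make_loader are made up for this illustration and are not part of PyTorch. It also checks that every epoch still visits each sample exactly once.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))


def epoch_order(loader):
    """Flatten one full pass over the loader into a list of ints (illustrative helper)."""
    return [int(x) for (batch,) in loader for x in batch]


def make_loader(seed):
    """Build a shuffling loader whose random order is driven by a seeded generator."""
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(dataset, batch_size=3, shuffle=True, generator=g)


loader_a = make_loader(0)
loader_b = make_loader(0)

first_a = epoch_order(loader_a)   # epoch 1 of loader A
second_a = epoch_order(loader_a)  # epoch 2 of loader A (generator state has advanced)
first_b = epoch_order(loader_b)   # epoch 1 of loader B, same seed as A

print('A, epoch 1:', first_a)
print('A, epoch 2:', second_a)
print('B, epoch 1:', first_b)
```

Two loaders built with the same seed produce the same first-epoch order, while successive epochs of a single loader keep drawing new permutations from the generator.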
Common Pitfalls
One common mistake is forgetting to set shuffle=True, which causes the data to be loaded in the same order every epoch. This can lead to poor model generalization because the model sees data in a fixed sequence.
Another pitfall is judging the effect of shuffle=True on a very small dataset, especially with batch_size=1: there are so few possible orderings that shuffled and unshuffled runs can look almost the same.
python
from torch.utils.data import DataLoader, TensorDataset
import torch

data = torch.arange(5)
dataset = TensorDataset(data)

# Wrong: shuffle=False (default)
loader_no_shuffle = DataLoader(dataset, batch_size=2)
print('No shuffle batches:')
for batch in loader_no_shuffle:
    print(batch[0].tolist())

# Right: shuffle=True
loader_shuffle = DataLoader(dataset, batch_size=2, shuffle=True)
print('Shuffle batches:')
for batch in loader_shuffle:
    print(batch[0].tolist())
Output
No shuffle batches:
[0, 1]
[2, 3]
[4]
Shuffle batches:
[3, 0]
[4, 1]
[2]
Quick Reference
Remember these tips when shuffling data in PyTorch DataLoader:
- Use shuffle=True to randomize data order each epoch.
- Shuffling helps prevent model overfitting to data order.
- For validation or test sets, keep shuffle=False to maintain consistent evaluation.
- Shuffling works only if the dataset supports indexing.
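The last two tips can be checked with a short sketch. It assumes a toy validation set of six numbers and a hypothetical iterable-style dataset, CountStream, which is defined here only for illustration. The unshuffled loader returns the same order on every pass, and asking for shuffle=True on a dataset without indexing makes DataLoader raise a ValueError.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset


class CountStream(IterableDataset):
    """Hypothetical iterable-style dataset: yields 0..n-1 and has no __getitem__."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))


# Validation loader: shuffle=False keeps the order identical on every pass.
val_dataset = TensorDataset(torch.arange(6))
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
val_order_1 = [int(x) for (batch,) in val_loader for x in batch]
val_order_2 = [int(x) for (batch,) in val_loader for x in batch]
print('Validation pass 1:', val_order_1)
print('Validation pass 2:', val_order_2)

# Iterable-style dataset: shuffle=True is rejected when the loader is created.
try:
    DataLoader(CountStream(6), batch_size=2, shuffle=True)
    shuffle_error = None
except ValueError as err:
    shuffle_error = err
print('Error for iterable dataset:', type(shuffle_error).__name__)
```

For streaming data like this, any shuffling has to happen inside the dataset itself, since the loader has no indices to permute.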
Key Takeaways
Set shuffle=True in DataLoader to randomize data order each epoch.
Shuffling improves model training by reducing bias from fixed data order.
Keep shuffle=False for validation and test DataLoaders to ensure consistent evaluation.
Shuffling requires datasets that support indexing (map-style datasets that implement __getitem__ and __len__).
Without shuffle=True, data loads in the same order every epoch.