How to Use WeightedRandomSampler in PyTorch for Balanced Sampling
Use WeightedRandomSampler in PyTorch to sample elements from a dataset according to per-element weights, which helps balance classes during training. You provide one weight for each data point and the number of samples to draw, then pass the sampler to a DataLoader to get weighted random batches.
Syntax
The WeightedRandomSampler constructor requires three main arguments:
- weights: a sequence of weights, one for each data point, proportional to its sampling probability (they need not sum to 1).
- num_samples: the number of samples to draw in each pass over the sampler.
- replacement: a boolean indicating whether sampling is done with replacement (usually True, so samples can repeat).
This sampler is passed to a DataLoader to control how batches are formed.
python
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
Example
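Taken on its own, the sampler is just an iterable of dataset indices. A minimal sketch (the weight values here are illustrative):

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Three data points; the last is three times as likely to be drawn as the others
weights = [1.0, 1.0, 3.0]
sampler = WeightedRandomSampler(weights, num_samples=5, replacement=True)

# Iterating the sampler yields dataset indices, not the data itself
indices = list(sampler)
print(indices)  # five indices in the range 0-2, e.g. [2, 0, 2, 2, 1]
```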
This example shows how to use WeightedRandomSampler to balance a dataset with two classes where one class is less frequent. It creates weights inversely proportional to class frequency and uses the sampler in a DataLoader.
python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy dataset with imbalanced classes: class 0 has 4 samples, class 1 has 2
features = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]])
labels = torch.tensor([0, 0, 0, 0, 1, 1])
dataset = TensorDataset(features, labels)

# Count samples per class
class_counts = torch.bincount(labels)

# Weight each data point by the inverse of its class frequency
weights = 1.0 / class_counts[labels]

# Draw len(weights) samples per epoch, with replacement
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

# The DataLoader uses the sampler instead of shuffling
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for batch_features, batch_labels in loader:
    print('Batch features:', batch_features.flatten().tolist())
    print('Batch labels:', batch_labels.tolist())
Output
Sampling is random, so exact values vary between runs; one possible run:
Batch features: [5.0, 2.0]
Batch labels: [1, 0]
Batch features: [6.0, 6.0]
Batch labels: [1, 1]
Batch features: [1.0, 5.0]
Batch labels: [0, 1]
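To confirm that inverse-frequency weights actually balance the classes, you can count sampled labels over many draws. A small sketch using an illustrative 4-vs-2 label split:

```python
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)  # make the run reproducible

labels = torch.tensor([0, 0, 0, 0, 1, 1])       # class 0: 4 samples, class 1: 2
weights = 1.0 / torch.bincount(labels)[labels]  # inverse class frequency

sampler = WeightedRandomSampler(weights, num_samples=10_000, replacement=True)
sampled_labels = labels[torch.tensor(list(sampler))]

# Both classes should be drawn roughly 5,000 times each
print('Sampled class counts:', torch.bincount(sampled_labels).tolist())
```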
Common Pitfalls
- Not using replacement: setting replacement=False raises an error when num_samples is larger than the number of weights, since each index can then be drawn at most once.
- Incorrect weights length: weights must contain exactly one entry per data point. A mismatch does not raise an error; the sampler silently draws only indices 0 to len(weights) - 1, so the remaining data points are never sampled.
- Ignoring class imbalance: weights that are not based on class frequency defeat the purpose of balanced sampling.
python
import torch
from torch.utils.data import WeightedRandomSampler

# Wrong: only 2 weights for a 3-sample dataset.
# No exception is raised -- the sampler just never yields index 2.
weights_wrong = [0.1, 0.9]
sampler_wrong = WeightedRandomSampler(weights_wrong, num_samples=3, replacement=True)
print('Indices from mismatched weights:', list(sampler_wrong))

# Right: one weight per data point, so every index can be sampled
weights_right = [0.1, 0.9, 0.5]
sampler_right = WeightedRandomSampler(weights_right, num_samples=3, replacement=True)
print('Indices from matching weights:', list(sampler_right))
Output
Indices are random and vary between runs; note that index 2 can never appear in the first line:
Indices from mismatched weights: [1, 1, 0]
Indices from matching weights: [1, 2, 1]
Quick Reference
WeightedRandomSampler Quick Tips:
- Use weights to assign a sampling probability to each data point.
- Set replacement=True to allow repeated samples within one epoch.
- Pass the sampler to DataLoader via the sampler argument (and do not also set shuffle=True; the two are mutually exclusive).
- Compute weights as the inverse of class frequency to balance classes.
- Make sure the weights length equals the dataset size so every item can be sampled.
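A related gotcha: DataLoader treats an explicit sampler and shuffle=True as mutually exclusive and raises a ValueError if you pass both. A quick sketch:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

dataset = TensorDataset(torch.arange(4, dtype=torch.float32).unsqueeze(1))
sampler = WeightedRandomSampler([1.0] * 4, num_samples=4, replacement=True)

try:
    # Wrong: shuffle=True conflicts with an explicit sampler
    DataLoader(dataset, sampler=sampler, shuffle=True)
except ValueError as e:
    print('Error:', e)

# Right: the sampler alone controls the order; leave shuffle at its default
loader = DataLoader(dataset, sampler=sampler, batch_size=2)
print('DataLoader created with sampler only')
```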
Key Takeaways
WeightedRandomSampler samples dataset elements based on specified weights to balance data during training.
Set replacement=True in typical balancing setups so minority-class samples can be drawn more than once per epoch.
Weights must be a sequence matching dataset size, often computed as inverse class frequencies.
Pass WeightedRandomSampler to DataLoader's sampler argument to control batch sampling.
Check weights length and replacement setting to avoid common errors.
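Putting the takeaways together, here is one possible end-to-end sketch on toy data (shapes and names are illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Imbalanced toy dataset: 8 samples of class 0, 2 of class 1
features = torch.randn(10, 3)
labels = torch.tensor([0] * 8 + [1] * 2)
dataset = TensorDataset(features, labels)

# One weight per data point: inverse of its class frequency
weights = 1.0 / torch.bincount(labels)[labels]

# num_samples = dataset size gives one "epoch" of weighted draws
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for batch_features, batch_labels in loader:
    # Each batch now contains roughly equal numbers of both classes
    print(batch_labels.tolist())
```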