PyTorch · How-To · Beginner · 3 min read

How to Use WeightedRandomSampler in PyTorch for Balanced Sampling

Use WeightedRandomSampler in PyTorch to sample elements from a dataset according to specified weights, which helps balance classes during training. You provide a list of weights for each data point and the number of samples to draw, then pass the sampler to a DataLoader to get weighted random batches.
📐

Syntax

The WeightedRandomSampler constructor takes three main arguments:

  • weights: a sequence of weights, one per data point, proportional to that point's sampling probability (they need not sum to 1).
  • num_samples: the number of samples to draw per epoch, i.e., the length of one pass through the DataLoader.
  • replacement: a boolean indicating whether sampling is with replacement (defaults to True, which allows repeated samples).

This sampler is passed to a DataLoader to control how batches are formed.

python
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
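A sampler can also be iterated on its own, outside a DataLoader, which is a handy way to see which dataset indices it yields. A minimal sketch (the weights here are arbitrary, chosen for illustration):

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Three data points; index 1 is heavily favored
weights = [0.1, 0.8, 0.1]
sampler = WeightedRandomSampler(weights, num_samples=5, replacement=True)

# Iterating the sampler yields dataset indices, not data
indices = list(sampler)
print(indices)  # five indices drawn from {0, 1, 2}
```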
💻

Example

This example shows how to use WeightedRandomSampler to balance a dataset with two classes where one class is less frequent. It creates weights inversely proportional to class frequency and uses the sampler in a DataLoader.

python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Sample dataset with imbalanced classes
features = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]])
labels = torch.tensor([0, 0, 0, 0, 1, 1])  # Class 0 appears twice as often as class 1

dataset = TensorDataset(features, labels)

# Calculate class counts
class_counts = torch.bincount(labels)

# Calculate weights: inverse of class frequency, one weight per data point
weights = 1.0 / class_counts[labels]

# Create WeightedRandomSampler
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

# DataLoader with sampler
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

# Iterate and print batches
for batch_features, batch_labels in loader:
    print('Batch features:', batch_features.flatten().tolist())
    print('Batch labels:', batch_labels.tolist())
Output (varies per run; the minority class now appears in roughly half the samples)
Batch features: [5.0, 3.0]
Batch labels: [1, 0]
Batch features: [6.0, 2.0]
Batch labels: [1, 0]
Batch features: [1.0, 5.0]
Batch labels: [0, 1]
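To check that inverse-frequency weights really do balance the classes, one can count how often each label is drawn over many samples. A quick sanity-check sketch (not part of the original example):

```python
import torch
from torch.utils.data import WeightedRandomSampler

labels = torch.tensor([0, 0, 0, 0, 1, 1])  # imbalanced: 4 vs 2
weights = 1.0 / torch.bincount(labels)[labels]

# Draw many samples and count the labels that come out
sampler = WeightedRandomSampler(weights, num_samples=10_000, replacement=True)
sampled_labels = labels[torch.tensor(list(sampler))]
counts = torch.bincount(sampled_labels)
print(counts)  # roughly equal counts for both classes
```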
⚠️

Common Pitfalls

  • Sampling without replacement: with replacement=False, iterating the sampler raises a RuntimeError if num_samples exceeds the number of weights.
  • Incorrect weights length: the sampler never sees the dataset, so a length mismatch is not detected; weights must have exactly one entry per data point.
  • Ignoring class imbalance: weights not derived from class frequency defeat the purpose of balancing.
python
import torch
from torch.utils.data import WeightedRandomSampler

# Wrong: weights length mismatch -- no exception is raised, the bug is silent
weights_wrong = [0.1, 0.9]  # Dataset has 3 samples
sampler_wrong = WeightedRandomSampler(weights_wrong, num_samples=3, replacement=True)
print('Indices drawn:', list(sampler_wrong))  # Index 2 can never appear

# Right: one weight per data point, so every index can be sampled
weights_right = [0.1, 0.9, 0.5]
sampler_right = WeightedRandomSampler(weights_right, num_samples=3, replacement=True)
print('Indices drawn:', list(sampler_right))
Output (varies per run)
Indices drawn: [1, 1, 0]
Indices drawn: [1, 2, 1]
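The replacement pitfall can be reproduced directly. With replacement=False, the underlying torch.multinomial call cannot draw more samples than there are weights, so iterating the sampler raises a RuntimeError. A small sketch (not from the original):

```python
import torch
from torch.utils.data import WeightedRandomSampler

weights = [0.2, 0.3, 0.5]  # only 3 data points
sampler = WeightedRandomSampler(weights, num_samples=5, replacement=False)

try:
    list(sampler)  # sampling happens lazily, on iteration
except RuntimeError as e:
    print('RuntimeError:', e)
```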
📊

Quick Reference

WeightedRandomSampler Quick Tips:

  • Use weights to assign sampling probability per data point.
  • Set replacement=True to allow repeated samples in one epoch.
  • Pass the sampler to DataLoader via the sampler argument.
  • Calculate weights as inverse of class frequency to balance classes.
  • Ensure the weights length equals the dataset size; a mismatch fails silently rather than raising an error.

Key Takeaways

WeightedRandomSampler samples dataset elements based on specified weights to balance data during training.
Set replacement=True (the usual choice) so that under-represented samples can be drawn more than once per epoch.
Weights must be a sequence matching dataset size, often computed as inverse class frequencies.
Pass WeightedRandomSampler to DataLoader's sampler argument to control batch sampling.
Check weights length and replacement setting to avoid common errors.