How to Use num_workers in PyTorch DataLoader for Faster Data Loading
In PyTorch, set the num_workers parameter of DataLoader to the number of subprocesses you want loading data in parallel. Increasing num_workers can speed up data loading by using multiple CPU cores, but setting it too high causes overhead or errors.

Syntax
The num_workers parameter in torch.utils.data.DataLoader controls how many subprocesses are used to load the data. A higher number means more parallel data loading.
- dataset: Your dataset object.
- batch_size: Number of samples per batch.
- shuffle: Whether to shuffle the data each epoch.
- num_workers: Number of subprocesses for loading data (default is 0, meaning data loads in the main process).
```python
DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
```
Example
This example shows how to use num_workers to speed up loading the MNIST dataset. It compares loading with num_workers=0 (single process) and num_workers=2 (two subprocesses).
```python
import time

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.ToTensor()
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# DataLoader with single worker (main process)
dataloader_single = DataLoader(mnist_dataset, batch_size=64, shuffle=True, num_workers=0)

# DataLoader with 2 workers (subprocesses)
dataloader_multi = DataLoader(mnist_dataset, batch_size=64, shuffle=True, num_workers=2)

def measure_loading_time(dataloader):
    start = time.time()
    for _ in dataloader:
        pass
    return time.time() - start

single_time = measure_loading_time(dataloader_single)
multi_time = measure_loading_time(dataloader_multi)

print(f"Loading time with num_workers=0: {single_time:.2f} seconds")
print(f"Loading time with num_workers=2: {multi_time:.2f} seconds")
```
Output (actual times vary by machine):

```
Loading time with num_workers=0: 5.12 seconds
Loading time with num_workers=2: 3.45 seconds
```
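How many workers is "too many" depends on your machine. A common starting heuristic (not an official PyTorch rule) is to derive num_workers from the CPU core count; the helper below is a hypothetical sketch of that idea:

```python
import os

def suggested_num_workers(reserve: int = 1) -> int:
    # Heuristic only: leave `reserve` cores free for the training loop itself.
    cores = os.cpu_count() or 1
    return max(cores - reserve, 0)

# Usage sketch:
# DataLoader(dataset, batch_size=64, num_workers=suggested_num_workers())
```

Treat the result as a starting point and benchmark, as in the timing example above.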
Common Pitfalls
Common mistakes when using num_workers include:
- Setting num_workers too high can cause your system to slow down or run out of memory.
- On Windows, you must protect your data loading code with if __name__ == '__main__': to avoid errors.
- Using num_workers > 0 can cause issues if your dataset or transforms are not picklable (cannot be serialized).
- Debugging is harder with multiple workers because errors may be hidden in subprocesses.
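You can check up front whether an object will survive the trip to a worker process. The helper below is a hypothetical utility (not part of PyTorch) built on Python's standard pickle module; note that a lambda transform fails the check, which is a common cause of num_workers > 0 errors:

```python
import pickle

# A lambda is not picklable, so worker subprocesses cannot receive it
scale_lambda = lambda x: x * 2

def is_picklable(obj) -> bool:
    """Return True if obj could be sent to a DataLoader worker process."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
```

If a transform fails this check, replace the lambda with a top-level named function, which pickle can serialize by reference.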
Example of Windows-safe usage:
```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def main():
    transform = transforms.ToTensor()
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)
    for images, labels in dataloader:
        pass

if __name__ == '__main__':
    main()
```
Quick Reference
| Parameter | Description | Default | Notes |
|---|---|---|---|
| dataset | Dataset object to load data from | None | Required |
| batch_size | Number of samples per batch | 1 | Adjust for memory and speed |
| shuffle | Shuffle data each epoch | False | Useful for training |
| num_workers | Number of subprocesses for loading data | 0 | 0 means main process; >0 uses parallel workers |
Key Takeaways
- Set num_workers > 0 to load data in parallel and speed up training.
- Too many workers can cause slowdowns or memory issues; tune based on your CPU and RAM.
- On Windows, always use the if __name__ == '__main__' guard when num_workers > 0.
- Ensure your dataset and transforms are picklable to avoid errors with multiple workers.
- Start with num_workers=0 to debug, then increase for better performance.
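Following that debug-first advice, one hypothetical pattern is to switch num_workers via an environment variable so you can drop to single-process loading without editing code (the variable name and the core-based fallback here are made up for illustration):

```python
import os

def loader_workers(env_var: str = "DATALOADER_DEBUG") -> int:
    # When the debug flag is set, load in the main process so tracebacks
    # surface directly; otherwise use a simple core-based heuristic.
    if os.environ.get(env_var):
        return 0
    return min(4, os.cpu_count() or 1)

# Usage sketch:
# DataLoader(dataset, batch_size=64, num_workers=loader_workers())
```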