
How to Use num_workers in PyTorch DataLoader for Faster Data Loading

In PyTorch, the num_workers parameter of DataLoader sets how many subprocesses load data in parallel. Raising num_workers can speed up data loading by using multiple CPU cores, but setting it too high adds inter-process overhead and memory pressure, and can even cause errors.
📐 Syntax

The num_workers parameter in torch.utils.data.DataLoader controls how many subprocesses are used to load the data. A higher number means more parallel data loading.

  • dataset: Your dataset object.
  • batch_size: Number of samples per batch.
  • shuffle: Whether to shuffle data each epoch.
  • num_workers: Number of subprocesses for loading data (default is 0, meaning data loads in the main process).
```python
DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
```
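There is no single correct value for num_workers. A common starting point (a heuristic, not an official PyTorch rule) is to cap it by the machine's CPU count and tune from there:

```python
import os

# Heuristic starting point (an assumption, not an official rule):
# up to one worker per CPU core, capped to avoid oversubscription.
cpu_cores = os.cpu_count() or 1
suggested_workers = min(4, cpu_cores)
print(f"CPU cores: {cpu_cores}, suggested num_workers: {suggested_workers}")
```

Benchmark a few values around this number on your own hardware rather than trusting any fixed rule.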
💻 Example

This example shows how num_workers can speed up iterating over the MNIST dataset. It compares loading with num_workers=0 (everything in the main process) and num_workers=2 (two worker subprocesses).

```python
import time
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.ToTensor()
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# DataLoader with no workers (main process only)
dataloader_single = DataLoader(mnist_dataset, batch_size=64, shuffle=True, num_workers=0)

# DataLoader with 2 workers (subprocesses)
dataloader_multi = DataLoader(mnist_dataset, batch_size=64, shuffle=True, num_workers=2)

def measure_loading_time(dataloader):
    start = time.time()
    for _ in dataloader:
        pass
    end = time.time()
    return end - start

single_time = measure_loading_time(dataloader_single)
multi_time = measure_loading_time(dataloader_multi)

print(f"Loading time with num_workers=0: {single_time:.2f} seconds")
print(f"Loading time with num_workers=2: {multi_time:.2f} seconds")
```
Output
Loading time with num_workers=0: 5.12 seconds
Loading time with num_workers=2: 3.45 seconds
⚠️ Common Pitfalls

Common mistakes when using num_workers include:

  • Setting num_workers too high can cause your system to slow down or run out of memory.
  • On Windows, you must guard your data-loading code with if __name__ == '__main__': — worker processes are started with the spawn method, which re-imports your script, so an unguarded DataLoader loop would try to start workers recursively and raise a RuntimeError.
  • Using num_workers > 0 can cause issues if your dataset or transforms are not picklable (cannot be serialized).
  • Debugging is harder with multiple workers because errors may be hidden in subprocesses.

Example of Windows-safe usage:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def main():
    transform = transforms.ToTensor()
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)
    for images, labels in dataloader:
        pass

if __name__ == '__main__':
    main()
```
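Worker startup cost is part of the overhead mentioned above: by default, worker processes are torn down and re-created at the start of every epoch. If you are on PyTorch 1.7 or newer, persistent_workers=True keeps them alive across epochs. A sketch using a small synthetic dataset as a stand-in for a real one:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def main():
    # Synthetic stand-in for a real dataset (random features, random labels)
    dataset = TensorDataset(torch.randn(512, 16), torch.randint(0, 10, (512,)))

    # persistent_workers=True (PyTorch >= 1.7) avoids re-spawning the two
    # worker processes at the start of every epoch
    loader = DataLoader(dataset, batch_size=64, shuffle=True,
                        num_workers=2, persistent_workers=True)

    for epoch in range(2):
        for batch, labels in loader:
            pass  # the same workers serve both epochs

if __name__ == '__main__':
    main()
```

Note that persistent_workers only applies when num_workers > 0, and the `__main__` guard is kept for the reason explained above.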
📊 Quick Reference

| Parameter | Description | Default | Notes |
|---|---|---|---|
| dataset | Dataset object to load data from | None | Required |
| batch_size | Number of samples per batch | 1 | Adjust for memory and speed |
| shuffle | Shuffle data each epoch | False | Useful for training |
| num_workers | Number of subprocesses for loading data | 0 | 0 means main process; >0 uses parallel workers |

Key Takeaways

  • Set num_workers > 0 to load data in parallel and speed up training.
  • Too many workers can cause slowdowns or memory issues; tune based on your CPU and RAM.
  • On Windows, always use an if __name__ == '__main__': guard when num_workers > 0.
  • Ensure your dataset and transforms are picklable to avoid errors with multiple workers.
  • Start with num_workers=0 to debug, then increase for better performance.
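The last takeaway can be turned into a small benchmark loop. This sketch sweeps a few worker counts over a synthetic TensorDataset (so it runs without any downloads); on your real dataset, pick whichever value is fastest:

```python
import time
import torch
from torch.utils.data import DataLoader, TensorDataset

def main():
    # Synthetic dataset standing in for your real one
    dataset = TensorDataset(torch.randn(2048, 32), torch.randint(0, 10, (2048,)))

    for workers in (0, 2, 4):
        loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=workers)
        start = time.time()
        for _ in loader:
            pass
        print(f"num_workers={workers}: {time.time() - start:.2f}s")

if __name__ == '__main__':
    main()
```

On a tiny in-memory dataset like this, num_workers=0 may actually win, because the worker startup cost outweighs the trivial loading work; the parallel settings pay off when per-sample loading (disk reads, decoding, transforms) is expensive.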