How to Use DistributedDataParallel in PyTorch for Multi-GPU Training
Use torch.nn.parallel.DistributedDataParallel by initializing a process group, moving your model to its assigned GPU, and wrapping it with DistributedDataParallel. This synchronizes gradients across multiple GPUs or nodes for faster, scalable training.
Syntax
The basic syntax to use DistributedDataParallel involves initializing the distributed environment, moving your model to the GPU, and wrapping it with DistributedDataParallel. You also need to use a DistributedSampler for your dataset to split data across processes.
- init_process_group(backend, init_method, world_size, rank): sets up communication between processes.
- model.to(device): moves the model to the correct GPU.
- DistributedDataParallel(model, device_ids=[device]): wraps the model for distributed training.
- DistributedSampler(dataset): ensures each process gets a unique subset of the data.
```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=rank
    )

# Inside each worker process, after calling setup(rank, world_size):
model = YourModel()
device = torch.device(f'cuda:{rank}')
model.to(device)
model = DDP(model, device_ids=[rank])
```
Example
This example shows a minimal runnable script for training a simple model on multiple GPUs using DistributedDataParallel. It initializes the process group, sets up the model and data loader with a distributed sampler, and runs a training loop.
```python
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

def demo_basic(rank, world_size):
    setup(rank, world_size)

    # Create the model and move it to the GPU with id `rank`
    model = SimpleModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    # Create the dataset and a distributed sampler so each process
    # sees a unique subset of the data
    dataset = TensorDataset(torch.randn(100, 10), torch.randn(100, 1))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffle the data each epoch
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(rank), y.to(rank)
            optimizer.zero_grad()
            outputs = ddp_model(X)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()
            if batch % 10 == 0 and rank == 0:
                print(f'Rank {rank}, Epoch {epoch}, Batch {batch}, Loss {loss.item():.4f}')

    cleanup()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    if world_size < 2:
        print('This example requires at least 2 GPUs to run.')
    else:
        mp.spawn(demo_basic, args=(world_size,), nprocs=world_size, join=True)
```
Output
Rank 0, Epoch 0, Batch 0, Loss 1.1234
Rank 0, Epoch 0, Batch 10, Loss 0.9876
Rank 0, Epoch 1, Batch 0, Loss 0.8765
Rank 0, Epoch 1, Batch 10, Loss 0.7654
Common Pitfalls
Common mistakes when using DistributedDataParallel include:
- Not initializing the process group before creating the model wrapper.
- Forgetting to move the model to the correct GPU before wrapping with DDP.
- Not using DistributedSampler for the dataset, causing data duplication across processes.
- Using the wrong backend (use nccl for GPUs, gloo for CPUs).
- Not setting the sampler's epoch each training epoch, which can cause data shuffling issues.
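The sampler pitfall is easy to verify on a single CPU process: DistributedSampler accepts explicit num_replicas and rank arguments, so a minimal sketch (with a tiny stand-in dataset) can show that the ranks receive disjoint index sets and that set_epoch reseeds the shuffle each epoch.

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

# A tiny stand-in dataset of 8 samples
dataset = TensorDataset(torch.arange(8).float().unsqueeze(1))

# Simulate two processes by constructing one sampler per rank explicitly;
# no process group is needed when num_replicas and rank are passed directly.
samplers = [DistributedSampler(dataset, num_replicas=2, rank=r) for r in range(2)]

for s in samplers:
    s.set_epoch(0)
indices_rank0 = list(samplers[0])
indices_rank1 = list(samplers[1])

# Each rank gets 4 of the 8 indices, with no overlap (no data duplication)
print(sorted(indices_rank0 + indices_rank1))  # every index appears exactly once

# Calling set_epoch with a new value reseeds the shuffle for the next epoch
samplers[0].set_epoch(1)
indices_epoch1 = list(samplers[0])
print(indices_rank0, indices_epoch1)
```

Because both samplers shuffle with the same seed-plus-epoch generator, they permute the same index list and take interleaved positions from it, which is what guarantees the disjoint split.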
Example of a wrong approach and the correct fix:
```python
# Wrong: model not moved to the GPU before wrapping with DDP
model = YourModel()
model = DDP(model)  # missing model.to(device)

# Right: move the model to its GPU first
device = torch.device(f'cuda:{rank}')
model = YourModel().to(device)
model = DDP(model, device_ids=[rank])
```
Quick Reference
Summary tips for using DistributedDataParallel:
- Always call dist.init_process_group before creating the DDP model.
- Move your model to the correct GPU with model.to(device) before wrapping.
- Use DistributedSampler to split data across processes.
- Set the sampler's epoch each training epoch with sampler.set_epoch(epoch).
- Use the nccl backend for GPU training for best performance.
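The example above spawns workers itself with mp.spawn, but PyTorch's torchrun launcher is a common alternative: it starts one process per GPU and exports the RANK, LOCAL_RANK, and WORLD_SIZE environment variables for each worker. A minimal sketch of reading them follows; the fallback defaults are an assumption so the snippet also runs as a plain single process.

```python
import os

# torchrun exports these variables for every worker it starts; the
# defaults below are assumed fallbacks for a standalone (non-launched) run.
rank = int(os.environ.get('RANK', 0))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
world_size = int(os.environ.get('WORLD_SIZE', 1))

print(f'rank={rank} local_rank={local_rank} world_size={world_size}')

# With init_method='env://', dist.init_process_group can then read
# MASTER_ADDR and MASTER_PORT from the environment as well, so no
# addresses need to be hard-coded in the script.
```

Under torchrun you would typically pass local_rank to model.to(...) and device_ids, since it identifies the GPU on the current node.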
Key Takeaways
- Initialize the process group before wrapping your model with DistributedDataParallel.
- Move your model to the correct GPU device before applying DistributedDataParallel.
- Use DistributedSampler to ensure each process gets a unique subset of the dataset.
- Set the sampler's epoch every training epoch to shuffle data correctly.
- Use the 'nccl' backend for efficient multi-GPU training on NVIDIA GPUs.