How to Use torch.distributed for Distributed Training in PyTorch
Use `torch.distributed` to run PyTorch models across multiple GPUs or machines by initializing a process group with `init_process_group`, then wrap your model with `DistributedDataParallel` to synchronize gradients during training.

Syntax
The main steps to use torch.distributed are:
- `torch.distributed.init_process_group(backend, init_method, world_size, rank)`: Initializes communication between processes.
- `torch.nn.parallel.DistributedDataParallel(model)`: Wraps your model to handle gradient synchronization.
- `torch.distributed.destroy_process_group()`: Cleans up the process group after training.
Parameters explained:
- `backend`: Communication backend, such as `nccl` (GPU) or `gloo` (CPU).
- `init_method`: URL specifying how processes find each other, e.g., a TCP address.
- `world_size`: Total number of processes.
- `rank`: Unique ID of the current process, from 0 to `world_size - 1`.
```python
import torch.distributed as dist

def setup(rank, world_size):
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:29500',
        world_size=world_size,
        rank=rank
    )

# Wrap model for distributed training:
# model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

# Cleanup:
# dist.destroy_process_group()
```
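Besides an explicit TCP address, `init_method` can be the default `env://` scheme, which reads the rendezvous address from environment variables. A minimal single-process sketch of that style, using the CPU-friendly `gloo` backend so it runs without GPUs:

```python
import os
import torch.distributed as dist

def setup_env_style(rank: int, world_size: int) -> None:
    # env:// reads MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE
    # from the environment instead of taking a TCP URL.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(
        backend="gloo",  # works on CPU, so this sketch runs without GPUs
        init_method="env://",
        world_size=world_size,
        rank=rank,
    )

# Single-process demo: world_size=1 initializes and tears down locally.
setup_env_style(rank=0, world_size=1)
print(dist.is_initialized())  # True while the group is up
dist.destroy_process_group()
```

In a real multi-process run, each process would pass its own `rank` while sharing the same `MASTER_ADDR` and `MASTER_PORT`.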
Example
This example shows a minimal setup for distributed training on a single machine with 2 GPUs. It initializes the process group, wraps a simple model, and runs one training step.
```python
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.multiprocessing import spawn

def setup(rank, world_size):
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:29500',
        world_size=world_size,
        rank=rank
    )

def cleanup():
    dist.destroy_process_group()

def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # Create the model and move it to the GPU with id `rank`
    model = nn.Linear(10, 10).to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # Dummy input and target
    inputs = torch.randn(20, 10).to(rank)
    targets = torch.randn(20, 10).to(rank)

    optimizer.zero_grad()
    outputs = ddp_model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()

    print(f"Rank {rank} loss: {loss.item()}")
    cleanup()

def run_demo():
    world_size = 2
    # spawn passes the process index as the first argument to demo_basic
    spawn(demo_basic, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    run_demo()
```
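The example above feeds every rank the same dummy tensors; in real training each rank should see a different shard of the dataset, which is what `torch.utils.data.DistributedSampler` provides. A sketch with a hypothetical 100-sample dataset, passing `num_replicas` and `rank` explicitly so it runs even without an initialized process group:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Hypothetical dataset of 100 (input, target) pairs for illustration.
dataset = TensorDataset(torch.randn(100, 10), torch.randn(100, 10))

# DistributedSampler splits indices across ranks so each process
# sees a disjoint shard of the data each epoch.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

print(len(sampler))  # 50: each of the 2 ranks gets half of the 100 samples
```

Inside a DDP training loop you would call `sampler.set_epoch(epoch)` at the start of each epoch so the shuffle order differs across epochs.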
Output
```
Running basic DDP example on rank 0.
Running basic DDP example on rank 1.
Rank 0 loss: 1.123456
Rank 1 loss: 1.234567
```

Note that the two processes run concurrently, so the order of the output lines can vary between runs.
Common Pitfalls
- Not setting unique ranks: Each process must have a unique `rank` from 0 to `world_size - 1`.
- Wrong backend: Use `nccl` for GPUs and `gloo` for CPUs.
- Not wrapping the model with DDP: Without `DistributedDataParallel`, gradients won't be synchronized.
- Forgetting to call `destroy_process_group()`: Can leave hanging processes.
- Reusing the same port across runs: Change the port in `init_method` if you restart quickly.
```python
import torch.distributed as dist

# Wrong: the same rank for every process
# dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:29500',
#                         world_size=2, rank=0)

# Right: a unique rank per process
# dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:29500',
#                         world_size=2, rank=rank)
```
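One way to sidestep the wrong-backend pitfall is to pick the backend at runtime. This small helper (a hypothetical name, not part of `torch.distributed`) falls back to `gloo` when no GPU is available:

```python
import torch

def pick_backend() -> str:
    # 'nccl' requires CUDA-capable GPUs; 'gloo' works on plain CPUs.
    return "nccl" if torch.cuda.is_available() else "gloo"

print(pick_backend())
```

You could then pass `backend=pick_backend()` to `init_process_group`, which also makes the same script usable on a CPU-only development machine.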
Quick Reference
Remember these key points when using torch.distributed:
- Initialize with `init_process_group` before training.
- Wrap your model with `DistributedDataParallel` for gradient synchronization.
- Use `spawn`, or launch multiple processes manually.
- Clean up with `destroy_process_group` after training.
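If you launch workers with the `torchrun` CLI instead of `spawn`, each process receives its identity through the environment variables `RANK`, `LOCAL_RANK`, and `WORLD_SIZE`. A small hypothetical helper to read them, with single-process defaults so the script also runs when invoked directly:

```python
import os

def dist_env(default_world_size: int = 1) -> dict:
    # torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE for each worker;
    # fall back to single-process defaults when launched directly.
    return {
        "rank": int(os.environ.get("RANK", 0)),
        "local_rank": int(os.environ.get("LOCAL_RANK", 0)),
        "world_size": int(os.environ.get("WORLD_SIZE", default_world_size)),
    }

print(dist_env())
```

A typical launch would be `torchrun --nproc_per_node=2 train.py` (with `train.py` standing in for your script name), after which each worker reads its rank from this environment rather than from `spawn` arguments.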
Key Takeaways
- Initialize the process group with unique ranks and the correct backend before training.
- Wrap your model with `DistributedDataParallel` to synchronize gradients across processes.
- Use `torch.multiprocessing.spawn` to launch multiple training processes easily.
- Always clean up with `destroy_process_group` to avoid hanging processes.
- Choose the right backend: `nccl` for GPUs and `gloo` for CPUs.