PyTorch · How-To · Beginner · 4 min read

How to Use torch.distributed for Distributed Training in PyTorch

Use torch.distributed to run PyTorch models across multiple GPUs or machines by initializing a process group with init_process_group. Then, wrap your model with DistributedDataParallel to synchronize gradients during training.
📐

Syntax

The main steps to use torch.distributed are:

  • torch.distributed.init_process_group(backend, init_method, world_size, rank): Initializes the communication between processes.
  • torch.nn.parallel.DistributedDataParallel(model): Wraps your model to handle gradient synchronization.
  • torch.distributed.destroy_process_group(): Cleans up the process group after training.

Parameters explained:

  • backend: Communication backend like nccl (GPU) or gloo (CPU).
  • init_method: URL specifying how to initialize, e.g., TCP address.
  • world_size: Total number of processes.
  • rank: Unique ID of the current process.
python
import torch.distributed as dist

def setup(rank, world_size):
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:29500',
        world_size=world_size,
        rank=rank
    )

# Wrap model for distributed training
# model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

# Cleanup
# dist.destroy_process_group()
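Instead of passing a TCP URL, you can let PyTorch read the rendezvous address from environment variables with the env:// init method. Below is a minimal sketch, assuming a single machine and the same address and port as above; the setup_env name is just an illustrative choice.

python
import os
import torch.distributed as dist

def setup_env(rank, world_size):
    # Normally set by your launcher; hardcoded here only for illustration
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    # env:// tells PyTorch to read MASTER_ADDR / MASTER_PORT from the environment
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank
    )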
💻

Example

This example shows a minimal setup for distributed training on a single machine with 2 GPUs. It initializes the process group, wraps a simple model, and runs one training step.

python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.multiprocessing import spawn

def setup(rank, world_size):
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:29500',
        world_size=world_size,
        rank=rank
    )

def cleanup():
    dist.destroy_process_group()

def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # Create model and move it to GPU with id rank
    model = nn.Linear(10, 10).to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # Dummy input and target
    inputs = torch.randn(20, 10).to(rank)
    targets = torch.randn(20, 10).to(rank)

    optimizer.zero_grad()
    outputs = ddp_model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()

    print(f"Rank {rank} loss: {loss.item()}")
    cleanup()

def run_demo():
    world_size = 2
    spawn(demo_basic, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    run_demo()
Output
Running basic DDP example on rank 0.
Rank 0 loss: 1.123456
Running basic DDP example on rank 1.
Rank 1 loss: 1.234567
(The exact line order may vary between runs, since both processes print concurrently.)
⚠️

Common Pitfalls

  • Not setting unique ranks: Each process must have a unique rank from 0 to world_size - 1.
  • Wrong backend: Use nccl for GPUs and gloo for CPUs.
  • Not wrapping model with DDP: Without DistributedDataParallel, gradients won't sync.
  • Forgetting to call destroy_process_group(): Can cause hanging processes.
  • Using the same port for multiple runs: If a previous run has not fully shut down, the port may still be busy; change the port in init_method before restarting.
python
import torch.distributed as dist

# Wrong: same rank for all processes
# dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:29500', world_size=2, rank=0)

# Right: unique rank per process
# dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:29500', world_size=2, rank=rank)
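For the last two pitfalls, one simple pattern is to make the port configurable and to guarantee cleanup with try/finally, so the process group is destroyed even if training crashes. A rough sketch follows; run_worker and reading the port from MASTER_PORT are illustrative choices, not part of the torch.distributed API.

python
import os
import torch.distributed as dist

def run_worker(rank, world_size):
    # Read the port from the environment so a quick restart can pick a free one
    port = os.environ.get("MASTER_PORT", "29500")
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://127.0.0.1:{port}",
        world_size=world_size,
        rank=rank
    )
    try:
        # ... forward/backward/optimizer steps go here ...
        pass
    finally:
        # Runs even if training raises, so no process is left hanging
        dist.destroy_process_group()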
📊

Quick Reference

Remember these key points when using torch.distributed:

  • Initialize with init_process_group before training.
  • Wrap your model with DistributedDataParallel for gradient sync.
  • Launch one process per GPU, either with torch.multiprocessing.spawn or an external launcher such as torchrun; see the sketch after this list.
  • Clean up with destroy_process_group after training.
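If you prefer a launcher over spawn, torchrun starts the processes for you and sets RANK, WORLD_SIZE, and LOCAL_RANK in each process's environment, so init_process_group can pick them up without explicit arguments. A minimal sketch, assuming the file is saved as train_ddp.py (a placeholder name) and launched with torchrun --nproc_per_node=2 train_ddp.py:

python
# Launched with: torchrun --nproc_per_node=2 train_ddp.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torchrun sets LOCAL_RANK (plus RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT)
    local_rank = int(os.environ["LOCAL_RANK"])

    # With those env vars set, no init_method/rank/world_size arguments are needed
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(10, 10).to(local_rank)
    ddp_model = DDP(model, device_ids=[local_rank])

    # ... training loop using ddp_model ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()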

Key Takeaways

  • Initialize the process group with unique ranks and the correct backend before training.
  • Wrap your model with DistributedDataParallel to synchronize gradients across processes.
  • Use torch.multiprocessing.spawn to launch multiple training processes easily.
  • Always clean up with destroy_process_group to avoid hanging processes.
  • Choose the right backend: 'nccl' for GPUs and 'gloo' for CPUs.
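If you only have CPUs, the same pattern works with the gloo backend. Here is a minimal sketch, assuming two processes on one machine; port 29501 is an arbitrary choice.

python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.multiprocessing import spawn

def cpu_worker(rank, world_size):
    # gloo is the CPU backend, so no GPUs are required
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29501",
        world_size=world_size,
        rank=rank
    )

    # No device_ids for a CPU model; DDP still averages gradients over gloo
    ddp_model = DDP(nn.Linear(10, 10))

    out = ddp_model(torch.randn(4, 10))
    out.sum().backward()

    dist.destroy_process_group()

if __name__ == "__main__":
    spawn(cpu_worker, args=(2,), nprocs=2, join=True)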