
Why distributed training handles large models in PyTorch

Introduction

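Distributed training spreads model training across several GPUs or machines. With data parallelism (DistributedDataParallel, or DDP), each process keeps a full copy of the model and trains on a different slice of the data, which speeds up training. To handle a model too large for one device's memory, sharding approaches such as FullyShardedDataParallel (FSDP) or pipeline and tensor parallelism split the model's parameters, gradients, and optimizer state across devices.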
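A back-of-the-envelope calculation shows why replication alone cannot fit a large model. The sketch below assumes fp32 training with the Adam optimizer, which needs about 16 bytes per parameter (4 for the weight, 4 for its gradient, 8 for Adam's two moment buffers), and it ignores activations and communication buffers. `per_gpu_state_bytes` is a hypothetical helper, not a PyTorch API.

```python
def per_gpu_state_bytes(num_params: int, world_size: int, sharded: bool,
                        bytes_per_param: int = 16) -> int:
    """Approximate per-GPU memory for weights, gradients and Adam state.

    Assumption: fp32 weights (4 B) + fp32 gradients (4 B) + two Adam moments (8 B).
    Activations and temporary buffers are ignored.
    """
    total = num_params * bytes_per_param
    # Replicated (DDP): every GPU holds the full state.
    # Fully sharded (FSDP-style): each GPU holds roughly 1/world_size of it.
    return total // world_size if sharded else total


if __name__ == "__main__":
    n = 7_000_000_000  # a hypothetical 7-billion-parameter model
    print(per_gpu_state_bytes(n, 8, sharded=False) / 1e9, "GB per GPU, replicated")
    print(per_gpu_state_bytes(n, 8, sharded=True) / 1e9, "GB per GPU, sharded over 8 GPUs")
```

For a 7-billion-parameter model this gives 112 GB per GPU when replicated versus 14 GB per GPU when sharded across 8 GPUs. Real FSDP adds some overhead because it temporarily gathers the full parameters of one layer group at a time.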

When your model is too big to fit into one GPU's memory (this calls for a sharding approach such as FSDP; DDP alone keeps a full copy on every GPU).
When you want to train faster by using multiple GPUs or machines.
When working with very large datasets that require more computing power.
When you want to train bigger models, which often reach higher accuracy.
When your project needs to scale up beyond a single device's limits.
Syntax
PyTorch
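import os
import torch
import torch.nn as nn
import torch.distributed as dist

# Launch with torchrun (e.g. torchrun --nproc_per_node=2 train.py);
# it sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)  # one GPU per process

# Initialize distributed training
dist.init_process_group(backend='nccl')

# Create model and move it to this process's GPU
model = nn.Linear(10, 10).cuda(local_rank)

# Wrap model for distributed training
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

# Training loop here...

dist.destroy_process_group()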

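Use dist.init_process_group to start distributed training. It initializes the default process group, and every process must call it before any communication happens. Without an explicit rank and world_size, it reads RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT from the environment, which torchrun sets for each process.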
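To make the environment-variable rendezvous concrete, the hypothetical helper below only parses the variables `torchrun` exports (no PyTorch needed), showing what one worker process sees. `LOCAL_RANK` is the per-machine index you would pass to `torch.cuda.set_device`; the example values are illustrative, not taken from a real run.

```python
import os
from typing import Mapping


def read_dist_env(env: Mapping[str, str] = os.environ) -> dict:
    """Parse the variables torchrun sets for each worker process (hypothetical helper)."""
    return {
        "rank": int(env["RANK"]),              # global index of this process
        "local_rank": int(env["LOCAL_RANK"]),  # index of this process on its machine
        "world_size": int(env["WORLD_SIZE"]),  # total number of processes
        "master": f"{env['MASTER_ADDR']}:{env['MASTER_PORT']}",  # rendezvous address
    }


if __name__ == "__main__":
    # Illustrative values for the second process on the second of 2 machines
    # with 4 processes per machine (8 processes in total)
    example = {"RANK": "5", "LOCAL_RANK": "1", "WORLD_SIZE": "8",
               "MASTER_ADDR": "10.0.0.1", "MASTER_PORT": "29500"}
    print(read_dist_env(example))
```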

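Wrap your model with DistributedDataParallel (DDP). DDP keeps a full replica of the model in each process; each process runs the forward and backward pass on its own batch, and DDP averages the gradients across processes (an all-reduce) during the backward pass, so every replica applies the same update.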
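The core of this update can be simulated in plain Python: each replica computes gradients on its own batch, the gradients are averaged across ranks, and every replica applies the same averaged gradient, so the copies stay identical. `all_reduce_mean` and `ddp_sgd_step` are hypothetical illustrations using lists in place of tensors; real DDP performs the all-reduce in gradient buckets, overlapped with the backward pass.

```python
from typing import List


def all_reduce_mean(per_rank_grads: List[List[float]]) -> List[float]:
    """Element-wise average of the gradients computed on each rank."""
    world_size = len(per_rank_grads)
    return [sum(values) / world_size for values in zip(*per_rank_grads)]


def ddp_sgd_step(params: List[float], per_rank_grads: List[List[float]],
                 lr: float) -> List[float]:
    """One SGD step with averaged gradients; every replica computes the same result."""
    avg_grad = all_reduce_mean(per_rank_grads)
    return [p - lr * g for p, g in zip(params, avg_grad)]


if __name__ == "__main__":
    params = [1.0, 2.0]          # identical on every replica
    grads = [[0.5, 1.0],         # gradients from rank 0's batch
             [1.5, 0.0]]         # gradients from rank 1's batch
    print(ddp_sgd_step(params, grads, lr=0.5))  # [0.5, 1.75]
```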

Examples
Basic setup for distributed training on GPUs using the NCCL backend.
PyTorch
import os
local_rank = int(os.environ['LOCAL_RANK'])  # set by torchrun
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')
model = nn.Linear(10, 10).cuda(local_rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
Setup for distributed training on CPUs using the Gloo backend.
PyTorch
dist.init_process_group(backend='gloo')  # Gloo supports CPU tensors
model = nn.Linear(10, 10).to('cpu')
model = nn.parallel.DistributedDataParallel(model)  # no device_ids for CPU modules
Sample Model

This script trains a simple model in 2 processes with DistributedDataParallel. Each process draws its own random batch and prints its loss for 3 epochs, so the losses differ between ranks.

PyTorch
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 2)

    def forward(self, x):
        return self.linear(x)

def train(rank, world_size):
    # Rendezvous address for the default env:// init method (single machine)
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # Gloo runs on CPU here, so the model stays on CPU and no device_ids are passed
    model = nn.parallel.DistributedDataParallel(SimpleModel())
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    # Dummy data: each rank draws its own random batch (its share of the data)
    torch.manual_seed(rank)
    inputs = torch.randn(10, 5)
    targets = torch.randn(10, 2)

    model.train()
    for epoch in range(3):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()  # DDP averages gradients across ranks here
        optimizer.step()
        print(f'Rank {rank}, Epoch {epoch}, Loss: {loss.item():.4f}')

    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = 2
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
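The sample gives each rank its own random tensors. With a real dataset, `torch.utils.data.DistributedSampler` gives each rank a disjoint share of the indices. The sketch below models its behavior with `shuffle=False` and `drop_last=False`: pad the index list by wrapping around until it divides evenly, then assign indices round-robin. `shard_indices` is a hypothetical helper, not the sampler itself.

```python
import math
from typing import List


def shard_indices(dataset_len: int, world_size: int, rank: int) -> List[int]:
    """Indices one rank would see, modeled on DistributedSampler(shuffle=False)."""
    indices = list(range(dataset_len))
    num_samples = math.ceil(dataset_len / world_size)  # samples per rank
    total_size = num_samples * world_size
    # Pad by repeating indices from the start so every rank gets num_samples
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]
    # Round-robin assignment: rank r takes every world_size-th index starting at r
    return indices[rank:total_size:world_size]


if __name__ == "__main__":
    for r in range(3):
        print(r, shard_indices(10, 3, r))
```

In a real training script you would pass `DistributedSampler(dataset)` as the `sampler` of a `DataLoader` and, when shuffling, call `sampler.set_epoch(epoch)` at the start of each epoch.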
Important Notes

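Distributed training requires all processes to communicate and synchronize: DDP all-reduces gradients on every backward pass, so each process must run the same number of training steps, or the others will block waiting for it.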
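Why this synchronization matters can be seen with a thread-based stand-in for an all-reduce: each simulated rank contributes its value and then blocks on a `threading.Barrier` until every rank has arrived. If one rank skipped the call, the others would wait forever, which is how a real job hangs when ranks run different numbers of steps. This is a simulation with Python threads, not how PyTorch implements collectives; `simulated_all_reduce_mean` is a hypothetical name.

```python
import threading
from typing import List


def simulated_all_reduce_mean(local_values: List[float]) -> List[float]:
    """Each 'rank' (thread) ends up with the mean of all ranks' values."""
    world_size = len(local_values)
    shared = [0.0] * world_size    # one slot per rank for its contribution
    results = [0.0] * world_size   # what each rank sees after the all-reduce
    barrier = threading.Barrier(world_size)

    def worker(rank: int) -> None:
        shared[rank] = local_values[rank]  # contribute the local gradient
        barrier.wait()                     # block until every rank has contributed
        results[rank] = sum(shared) / world_size

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


if __name__ == "__main__":
    print(simulated_all_reduce_mean([1.0, 3.0]))  # [2.0, 2.0]
```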

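Using multiple GPUs with DDP can shorten training because each GPU processes part of every global batch, although communication overhead usually keeps the speedup below linear. Fitting a model that is too big for one GPU requires sharding, for example with FSDP.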
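One concrete way extra GPUs shorten an epoch: with DDP each GPU processes its own per-GPU batch, so the global batch grows with the number of GPUs and the number of synchronized optimizer steps per epoch shrinks. `steps_per_epoch` is a hypothetical helper; it ignores sampler padding details and communication cost, so it bounds the step count rather than predicting wall-clock time.

```python
import math


def steps_per_epoch(dataset_size: int, per_gpu_batch: int, world_size: int) -> int:
    """Optimizer steps needed to see the whole dataset once (last partial batch kept)."""
    global_batch = per_gpu_batch * world_size  # samples consumed per synchronized step
    return math.ceil(dataset_size / global_batch)


if __name__ == "__main__":
    for gpus in (1, 2, 4, 8):
        print(gpus, "GPU(s):", steps_per_epoch(50_000, 32, gpus), "steps")
```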

Make sure to set the correct backend and device IDs for your hardware: NCCL with device_ids=[local_rank] for NVIDIA GPUs, and Gloo with no device_ids for CPUs.
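The backend rule can be captured in a small hypothetical helper that returns the process-group backend and the `device_ids` value to hand to DDP. It takes a plain boolean so it runs without PyTorch; in a real script you would pass `torch.cuda.is_available()` and the `LOCAL_RANK` value set by torchrun.

```python
from typing import List, Optional, Tuple


def backend_and_device_ids(cuda_available: bool,
                           local_rank: int) -> Tuple[str, Optional[List[int]]]:
    """Pick the process-group backend and DDP device_ids for this process."""
    if cuda_available:
        # NVIDIA GPUs: NCCL, with one GPU per process identified by its local rank
        return "nccl", [local_rank]
    # CPU-only: Gloo, and device_ids stays None for CPU modules
    return "gloo", None


if __name__ == "__main__":
    print(backend_and_device_ids(True, 3))
    print(backend_and_device_ids(False, 3))
```

Usage in a training script: `backend, ids = backend_and_device_ids(torch.cuda.is_available(), local_rank)`, then `dist.init_process_group(backend=backend)` and `DistributedDataParallel(model, device_ids=ids)`.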

Summary

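Distributed training spreads training across many devices; sharding approaches such as FSDP split a big model's parameters, gradients, and optimizer state to work around per-device memory limits.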

It helps train faster and use larger models than one device can handle.

PyTorch provides DistributedDataParallel for data-parallel training and FullyShardedDataParallel for models too large to replicate on every device.