How to Use Multiple GPUs in PyTorch for Faster Training
To use multiple GPUs in PyTorch, wrap your model with torch.nn.DataParallel, or use torch.nn.parallel.DistributedDataParallel for better performance. Either wrapper lets PyTorch split input batches across the GPUs and combine the results automatically during training.
Syntax
PyTorch provides two main ways to use multiple GPUs: DataParallel and DistributedDataParallel.
- DataParallel: Wrap your model with torch.nn.DataParallel(model). It splits input batches across GPUs and gathers the outputs.
- DistributedDataParallel: More efficient for multi-GPU training, especially across multiple machines. Requires initializing a process group and wrapping the model with torch.nn.parallel.DistributedDataParallel(model).
```python
import torch
import torch.nn as nn

# Using DataParallel
model = YourModel()
model = nn.DataParallel(model)  # Wrap model for multi-GPU

# Using DistributedDataParallel (simplified example)
import torch.distributed as dist

dist.init_process_group(backend='nccl')
model = YourModel().to(device)
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```
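The simplified snippet above leaves out where `device` and `local_rank` come from. A more complete sketch is below; it assumes you launch the script with `torchrun --nproc_per_node=<num_gpus> train.py`, which sets the RANK, WORLD_SIZE, and LOCAL_RANK environment variables for each process. The single-process defaults are only there so the sketch runs standalone.

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn

# Defaults so the sketch runs as a single process; torchrun sets these for you.
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')

# nccl is the recommended backend for GPUs; gloo works on CPU-only machines.
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
dist.init_process_group(backend=backend)

local_rank = int(os.environ.get('LOCAL_RANK', 0))
model = nn.Linear(10, 2)  # stand-in for your real model
if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)       # bind this process to one GPU
    model = model.cuda(local_rank)
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
else:
    ddp_model = nn.parallel.DistributedDataParallel(model)

device = next(model.parameters()).device
out = ddp_model(torch.randn(4, 10).to(device))
print(out.shape)  # torch.Size([4, 2])

dist.destroy_process_group()
```

Each process launched by torchrun runs this same script and handles one GPU, with gradients synchronized across processes during backward.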
Example
This example shows how to use DataParallel to train a simple model on multiple GPUs. It automatically splits input data and combines outputs.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# Check if GPUs are available
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")

model = SimpleNet()
if torch.cuda.is_available():
    model = nn.DataParallel(model)  # Wrap model for multi-GPU
    model = model.cuda()

# Dummy data, moved to the same device as the model
inputs = torch.randn(16, 10)
labels = torch.randint(0, 2, (16,))
if torch.cuda.is_available():
    inputs, labels = inputs.cuda(), labels.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training step
model.train()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f"Loss: {loss.item():.4f}")
```
Output
Using 2 GPUs
Loss: 0.6931
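The single training step above extends naturally to a full loop over a DataLoader. The sketch below reuses the same SimpleNet idea with a random TensorDataset as a stand-in for real data; DataParallel splits each batch across however many GPUs are visible.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleNet()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # each batch is split across the GPUs
model = model.to(device)

# Random stand-in data: 128 samples, 10 features, 2 classes
dataset = TensorDataset(torch.randn(128, 10), torch.randint(0, 2, (128,)))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

model.train()
for epoch in range(2):
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)  # match the model's device
        optimizer.zero_grad()   # clear gradients from the previous step
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}")
```

Note the call to optimizer.zero_grad() inside the loop; without it, gradients accumulate across steps.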
Common Pitfalls
Common mistakes when using multiple GPUs in PyTorch include:
- Not moving input data to the correct device (GPU). Inputs must be on the same device as the model.
- Using DataParallel on CPU-only machines will not speed up training.
- Accessing model attributes directly after wrapping with DataParallel fails; use model.module instead.
- DistributedDataParallel requires proper setup of process groups and environment variables.
```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
model = nn.DataParallel(model)

# Wrong: accessing original model attribute
# print(model.weight)  # AttributeError

# Right: access the wrapped model
print(model.module.weight)
```
Output
Parameter containing:
tensor([[ 0.1234, -0.2345, ..., 0.3456]], requires_grad=True)
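The same model.module indirection matters when saving checkpoints. Saving the wrapper's state_dict prefixes every key with "module.", which breaks loading into a plain, unwrapped model later. A sketch (using an in-memory buffer in place of a file path):

```python
import io
import torch
import torch.nn as nn

model = nn.DataParallel(nn.Linear(10, 2))

# Save the underlying model's weights, not the wrapper's, so the checkpoint
# loads cleanly into an unwrapped model; model.state_dict() would prefix
# every key with "module.".
buffer = io.BytesIO()
torch.save(model.module.state_dict(), buffer)

buffer.seek(0)
plain_model = nn.Linear(10, 2)
plain_model.load_state_dict(torch.load(buffer))
print(list(plain_model.state_dict().keys()))  # ['weight', 'bias']
```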
Quick Reference
Tips for using multiple GPUs in PyTorch:
- Use DataParallel for simple multi-GPU training on one machine.
- Use DistributedDataParallel for better speed and multi-node training.
- Always move inputs and targets to the correct GPU device.
- Access the original model with model.module when using DataParallel.
- Set backend='nccl' for GPU communication in DistributedDataParallel.
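The single-machine tips above can be collected into a pair of small helpers. The function names here are ours, not part of PyTorch; this is just a convenience sketch, and for serious multi-GPU work DistributedDataParallel remains the better choice.

```python
import torch
import torch.nn as nn

def wrap_for_gpus(model: nn.Module) -> nn.Module:
    """Move a model to GPU and wrap it in DataParallel when several GPUs exist."""
    if torch.cuda.is_available():
        model = model.cuda()
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
    return model  # unchanged on CPU-only machines

def unwrap(model: nn.Module) -> nn.Module:
    """Recover the original model, whether or not it was wrapped."""
    return model.module if isinstance(model, nn.DataParallel) else model

net = wrap_for_gpus(nn.Linear(10, 2))
print(type(unwrap(net)))  # <class 'torch.nn.modules.linear.Linear'>
```

Using unwrap() everywhere you touch model attributes or save checkpoints keeps the code correct regardless of how many GPUs the machine has.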
Key Takeaways
- Wrap your model with torch.nn.DataParallel or torch.nn.parallel.DistributedDataParallel to use multiple GPUs.
- Always move your input data to the GPU devices before feeding it to the model.
- Access the original model via model.module when using DataParallel.
- DistributedDataParallel offers better performance and scalability than DataParallel.
- Proper setup of device IDs and process groups is essential for DistributedDataParallel.