"""Single-node multi-GPU training of ResNet-18 on CIFAR-10 with
DistributedDataParallel (one process per GPU, NCCL backend)."""
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torchvision.models import resnet18


def setup(rank, world_size):
    # Rendezvous via environment variables; every process joins the
    # same default process group.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    # Pin this process to its GPU so NCCL collectives and allocations
    # do not all land on cuda:0.
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def train(rank, world_size):
    setup(rank, world_size)
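    # Same seed on every rank keeps rank-local randomness reproducible;
    # DDP broadcasts rank 0's initial weights at wrap time regardless.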
torch.manual_seed(42)
device = torch.device(f'cuda:{rank}')
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
    # Download once on rank 0; concurrent download=True calls from every
    # process can race on ./data. The barrier holds the other ranks until
    # the files are on disk.
    if rank == 0:
        datasets.CIFAR10('./data', train=True, download=True)
        datasets.CIFAR10('./data', train=False, download=True)
    dist.barrier()
    train_dataset = datasets.CIFAR10('./data', train=True, transform=transform)
    val_dataset = datasets.CIFAR10('./data', train=False, transform=transform)
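    # DistributedSampler gives each rank a disjoint 1/world_size shard of
    # the training set and handles shuffling itself, so the loader must not.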
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
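    # batch_size is per process: the effective global batch is
    # 32 * world_size. The val loader is only consumed on rank 0.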
train_loader = DataLoader(train_dataset, batch_size=32, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=32)
model = resnet18(num_classes=10).to(device)
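    # DDP hooks into backward() and all-reduces (averages) gradients
    # across ranks, so every process steps with identical gradients.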
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
epochs = 10
for epoch in range(epochs):
model.train()
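        # Reseed the sampler's shuffle for this epoch; without set_epoch
        # every epoch replays the same per-rank ordering.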
train_sampler.set_epoch(epoch)
total_loss = 0.0
correct = 0
total = 0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
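            # Per-rank partial sums; loss is weighted by batch size so the
            # final average is per sample, not per batch.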
total_loss += loss.item() * inputs.size(0)
_, predicted = torch.max(outputs, 1)
correct += (predicted == labels).sum().item()
total += labels.size(0)
        # Aggregate training metrics across GPUs with a single collective
        # (dist.all_reduce defaults to ReduceOp.SUM).
        metrics = torch.tensor([total_loss, correct, total],
                               dtype=torch.float64, device=device)
        dist.all_reduce(metrics)
        train_loss = metrics[0].item() / metrics[2].item()
        train_acc = metrics[1].item() / metrics[2].item() * 100
        if rank == 0:
            # Validate on rank 0 only, through model.module: running the
            # DDP wrapper's forward on a single rank can deadlock, because
            # it may broadcast buffers (a collective op) while the other
            # ranks never join in.
            model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0
with torch.no_grad():
for inputs, labels in val_loader:
inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model.module(inputs)
loss = criterion(outputs, labels)
val_loss += loss.item() * inputs.size(0)
_, predicted = torch.max(outputs, 1)
val_correct += (predicted == labels).sum().item()
val_total += labels.size(0)
val_loss /= val_total
val_acc = val_correct / val_total * 100
print(f'Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
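        # Re-sync before the next epoch; rank 0 lags the others while it
        # runs validation. (Optional, but keeps the ranks in step.)
        dist.barrier()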
cleanup()


if __name__ == '__main__':
    # One process per visible GPU; mp.spawn passes the process index
    # (used here as both the rank and the CUDA device index) as the
    # first argument of train().
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)