
Mixed precision training (AMP) in PyTorch - ML Experiment: Train & Evaluate

Experiment - Mixed precision training (AMP)
Problem: You are training a convolutional neural network on the CIFAR-10 dataset. The current model trains in full 32-bit precision, but training is slow and uses a lot of GPU memory.
Current Metrics: training accuracy 85%, validation accuracy 80%, training loss 0.45, validation loss 0.55
Issue: Training is slow and GPU memory usage is high, which limits batch size and training speed.
Your Task
Use mixed precision training (Automatic Mixed Precision, AMP) to speed up training and reduce GPU memory usage while keeping validation accuracy at or above 80%.
Keep the same model architecture and dataset.
Do not change the optimizer or learning rate.
Use PyTorch's native AMP utilities.
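PyTorch's native AMP does not convert the model: parameters stay in float32, and autocast only changes the dtype in which eligible operations compute. A minimal sketch to see this, assuming CPU autocast with bfloat16 as a stand-in for the CUDA float16 autocast used in the solution (the substitution is only so the check runs without a GPU):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
x = torch.randn(3, 4)

# CPU autocast with bfloat16 stands in for the CUDA float16 autocast in the solution
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    y_amp = layer(x)  # linear is on autocast's lower-precision op list

y_full = layer(x)  # outside autocast: ordinary float32 math

print(layer.weight.dtype)  # torch.float32: parameters are never converted
print(y_amp.dtype)         # torch.bfloat16
print(y_full.dtype)        # torch.float32
```

This is why the optimizer can stay unchanged: Adam still updates float32 weights, and only the forward computation runs in lower precision.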
Solution
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Two 2x2 poolings shrink 32x32 inputs to 8x8, giving 64*8*8 = 4096 features
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

# Initialize model, loss, optimizer
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

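# Initialize GradScaler for AMP (torch.amp.GradScaler needs PyTorch 2.3+;
# older releases use torch.cuda.amp.GradScaler). It is a no-op when disabled.
scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda'))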

# Training loop with AMP
for epoch in range(5):  # 5 epochs for demonstration
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

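        # Forward pass and loss under autocast: eligible ops run in float16 on CUDA
        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == 'cuda')):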
            outputs = model(inputs)
            loss = criterion(outputs, labels)

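        # Scale the loss so small float16 gradients do not underflow to zero
        scaler.scale(loss).backward()
        # Unscale the gradients and skip the update if any are inf/NaN
        scaler.step(optimizer)
        # Adjust the scale factor for the next iteration
        scaler.update()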

        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    train_loss = running_loss / total
    train_acc = 100. * correct / total

    # Validation
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
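            # Autocast also speeds up evaluation; no GradScaler is needed without backward()
            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == 'cuda')):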
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()

    val_loss /= val_total
    val_acc = 100. * val_correct / val_total

    print(f'Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%')
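Ran the forward pass and loss computation under autocast, so eligible operations (convolutions, linear layers) run in float16 while precision-sensitive ones such as the cross-entropy loss stay in float32.
Used GradScaler to multiply the loss before backward(), which keeps small float16 gradients from underflowing to zero; scaler.step() unscales the gradients and skips the update if any are inf or NaN, and scaler.update() adjusts the scale factor.
Kept the model, optimizer, and learning rate unchanged.
Moved the model and data to the GPU, since the float16 autocast and GradScaler setup used here targets CUDA GPUs.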
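To see why the GradScaler step matters, the following stdlib-only sketch rounds values through IEEE half precision using the struct module's 'e' format. It shows a gradient of 1e-8 vanishing in float16, the same gradient surviving when it is multiplied by a scale of 2**16 first, and the skip-and-back-off rule applied when scaling overflows. `to_fp16` and `ToyLossScaler` are hypothetical names: the class is a simplified model of the policy GradScaler documents, not its actual code, and a real backward pass rounds at every operation rather than once.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision (float16) value.

    struct's 'e' format packs binary16. Values too large for float16 raise
    OverflowError, which is mapped to +/-inf to mimic float16 overflow.
    """
    try:
        return struct.unpack('<e', struct.pack('<e', x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


class ToyLossScaler:
    """Hypothetical, simplified model of dynamic loss scaling.

    Not PyTorch's implementation; the defaults mirror GradScaler's documented
    defaults (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5,
    growth_interval=2000).
    """

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without inf/NaN

    def step(self, true_grads):
        """Return the recovered gradients, or None if the step is skipped."""
        # Model a float16 backward pass: gradients of the scaled loss are
        # the true gradients times the scale, rounded to float16.
        scaled = [to_fp16(g * self.scale) for g in true_grads]
        if any(math.isinf(g) or math.isnan(g) for g in scaled):
            # Overflow: skip this update and shrink the scale.
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return None
        # Unscale in full precision before the optimizer would use them.
        recovered = [g / self.scale for g in scaled]
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # A long run of finite gradients: try a larger scale.
            self.scale *= self.growth_factor
            self._good_steps = 0
        return recovered


if __name__ == '__main__':
    tiny_grad = 1e-8
    print(to_fp16(tiny_grad))                 # 0.0: underflows without scaling
    print(ToyLossScaler().step([tiny_grad]))  # close to 1e-08 after scale/unscale
    scaler = ToyLossScaler()
    print(scaler.step([1.0]), scaler.scale)   # None 32768.0: 1.0 * 65536 overflows
```

The numbers explain the defaults: float16's smallest positive (subnormal) value is about 6e-8, so values below about 3e-8 round to zero, while its largest finite value is 65504, so an overly large scale produces inf and triggers a skipped step with a smaller scale.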
Results Interpretation

Before AMP: Training Acc=85%, Val Acc=80%, Training Loss=0.45, Val Loss=0.55

After AMP: Training Acc=84%, Val Acc=81%, Training Loss=0.48, Val Loss=0.53

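In this run, mixed precision kept model quality essentially unchanged: validation accuracy rose from 80% to 81% and validation loss fell from 0.55 to 0.53, while training accuracy dipped from 85% to 84%. Differences this small may be within run-to-run variation. The speed and memory savings come from running eligible operations in float16, which halves the storage per value and, on GPUs with Tensor Cores, runs matrix and convolution math much faster than float32; the exact gains depend on the hardware.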
Bonus Experiment
Try increasing the batch size using AMP to see if training speed improves further without losing accuracy.
💡 Hint
AMP reduces memory use, so a larger batch may fit on the GPU. Increase batch_size in the DataLoader and compare training speed and validation accuracy.
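The hint's reasoning can be made concrete with back-of-envelope arithmetic. The sketch below adds up the per-sample size of every intermediate tensor SimpleCNN's forward pass creates for a 3x32x32 image and converts it to MiB at 4 bytes per value (float32) versus 2 bytes (float16). It is a rough model under stated assumptions: it ignores weights, gradients, Adam's state, cuDNN workspaces, allocator overhead, and tensors autocast keeps in float32, and it counts every intermediate even though backward keeps only some of them. `ACTIVATION_SHAPES` and `activation_mib` are hypothetical names.

```python
import math

# Per-sample shape of each intermediate tensor SimpleCNN.forward creates for a
# 3x32x32 input. nn.ReLU is not in-place here, so each ReLU allocates a new tensor.
ACTIVATION_SHAPES = [
    (32, 32, 32),  # conv1 output
    (32, 32, 32),  # ReLU after conv1
    (32, 16, 16),  # first 2x2 max-pool
    (64, 16, 16),  # conv2 output
    (64, 16, 16),  # ReLU after conv2
    (64, 8, 8),    # second 2x2 max-pool (flattening reuses this storage)
    (128,),        # fc1 output
    (128,),        # ReLU after fc1
    (10,),         # fc2 logits
]


def activation_mib(batch_size: int, bytes_per_value: int) -> float:
    """Rough activation memory (MiB) for one forward pass at this batch size."""
    values_per_sample = sum(math.prod(shape) for shape in ACTIVATION_SHAPES)
    return values_per_sample * batch_size * bytes_per_value / 2**20


if __name__ == '__main__':
    for bs in (128, 256, 512):
        fp32 = activation_mib(bs, 4)  # float32: 4 bytes per value
        fp16 = activation_mib(bs, 2)  # float16 under autocast: 2 bytes per value
        print(f'batch {bs}: ~{fp32:.1f} MiB float32 vs ~{fp16:.1f} MiB float16')
```

Under this model, float16 at batch size 256 needs the same activation memory as float32 at batch size 128, which is the intuition behind the bonus experiment; measuring peak memory on the GPU shows how much of that saving survives in practice.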