How to Use GradScaler in PyTorch for Mixed Precision Training
Use torch.cuda.amp.GradScaler to scale gradients during mixed precision training and prevent underflow. Run the forward pass inside an autocast() context, call scaler.scale(loss).backward() to backpropagate, then scaler.step(optimizer) and scaler.update() to update the weights safely.

Syntax
The typical usage pattern of GradScaler involves these steps:
- Create a scaler object: scaler = torch.cuda.amp.GradScaler()
- Run the forward pass inside the autocast() context so it executes in mixed precision.
- Scale the loss before backward: scaler.scale(loss).backward()
- Step the optimizer with the scaled gradients: scaler.step(optimizer)
- Update the scaler for the next iteration: scaler.update()
This helps prevent gradient underflow and keeps training stable.
```python
scaler = torch.cuda.amp.GradScaler()

for data, target in dataloader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        output = model(data)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
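If you need gradient clipping, clip after unscaling; otherwise you would be clipping the scaled gradients. GradScaler provides unscale_() for exactly this. A sketch of the unscale-then-clip pattern (the max_norm value is an arbitrary choice for illustration):

```python
scaler = torch.cuda.amp.GradScaler()

for data, target in dataloader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        output = model(data)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()

    # Unscale the gradients in place so clipping sees their true magnitudes
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    scaler.step(optimizer)  # aware the grads are already unscaled; skips on inf/NaN
    scaler.update()
```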
Example
This example shows a simple training loop using GradScaler for mixed precision training on a dummy dataset with a linear model.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

# Simple model
model = nn.Linear(10, 1).cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
scaler = GradScaler()

# Dummy data
inputs = torch.randn(16, 10).cuda()
targets = torch.randn(16, 1).cuda()

model.train()
for epoch in range(3):
    optimizer.zero_grad()
    with autocast():
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
```
Output

Exact values vary from run to run because the inputs and targets are random; a typical run prints something like:
Epoch 1, Loss: 1.1234
Epoch 2, Loss: 0.9876
Epoch 3, Loss: 0.8765
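GradScaler and autocast() both accept an enabled flag, so the same script can fall back to ordinary full precision training on a machine without a GPU. A minimal sketch of that pattern, mirroring the example above:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

model = nn.Linear(10, 1).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
scaler = GradScaler(enabled=use_amp)  # behaves as a no-op when disabled

inputs = torch.randn(16, 10, device=device)
targets = torch.randn(16, 1, device=device)

optimizer.zero_grad()
with autocast(enabled=use_amp):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
scaler.scale(loss).backward()  # equivalent to plain backward() when disabled
scaler.step(optimizer)
scaler.update()
```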
Common Pitfalls
Common mistakes when using GradScaler include:
- Not using autocast() during the forward pass, which forfeits the benefits of mixed precision.
- Calling loss.backward() instead of scaler.scale(loss).backward(), which breaks gradient scaling.
- Calling optimizer.step() directly instead of scaler.step(optimizer), which can apply updates from overflowed gradients.
- Forgetting to call scaler.update() after stepping the optimizer, so the scale factor is never adjusted.
Always follow the pattern: autocast() → scaler.scale(loss).backward() → scaler.step(optimizer) → scaler.update().
```python
# Wrong way (no scaling):
# loss.backward()
# optimizer.step()

# Right way with GradScaler:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```
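When the scaled gradients contain infs or NaNs, scaler.step(optimizer) silently skips the update and scaler.update() then lowers the scale factor. A common idiom for detecting a skipped step, for example to avoid advancing an LR scheduler on that iteration, is to compare the scale before and after update(); a sketch:

```python
scaler.scale(loss).backward()

scale_before = scaler.get_scale()
scaler.step(optimizer)  # skipped internally if grads contain inf/NaN
scaler.update()         # shrinks the scale after an overflow

if scaler.get_scale() < scale_before:
    # The optimizer step was skipped this iteration,
    # e.g. don't step the LR scheduler here
    print("Overflow detected; optimizer step skipped")
```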
Quick Reference
| Step | Code | Purpose |
|---|---|---|
| Create scaler | scaler = torch.cuda.amp.GradScaler() | Initialize gradient scaler |
| Forward pass | with autocast(): output = model(input) | Run mixed precision forward |
| Scale loss & backward | scaler.scale(loss).backward() | Scale gradients to avoid underflow |
| Optimizer step | scaler.step(optimizer) | Apply optimizer step safely |
| Update scaler | scaler.update() | Adjust scaling for next iteration |
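One step the table does not cover: GradScaler carries internal state (the current scale and its growth tracking), so save and restore it alongside the model and optimizer when checkpointing. A minimal sketch; checkpoint.pt is just a placeholder path:

```python
# Saving: include the scaler state with the model and optimizer
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}
torch.save(checkpoint, "checkpoint.pt")

# Resuming: restore all three before continuing training
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])
```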
Key Takeaways
- Use torch.cuda.amp.GradScaler with autocast() to enable safe mixed precision training.
- Always scale the loss before backward with scaler.scale(loss).backward().
- Use scaler.step(optimizer) instead of optimizer.step() to apply gradients correctly.
- Call scaler.update() after the optimizer step to keep the scaling factor current.
- Missing any step can cause training instability or loss of mixed precision benefits.