How to Use Gradient Clipping in PyTorch for Stable Training
In PyTorch, use torch.nn.utils.clip_grad_norm_ or torch.nn.utils.clip_grad_value_ to clip gradients before the optimizer step. This limits gradient size, preventing exploding gradients and improving training stability.
Syntax
Gradient clipping in PyTorch is done using utility functions that modify gradients in-place before updating model weights.
- torch.nn.utils.clip_grad_norm_(parameters, max_norm): Clips gradients so their total norm does not exceed max_norm.
- torch.nn.utils.clip_grad_value_(parameters, clip_value): Clips gradients so their absolute values do not exceed clip_value.
parameters is usually model.parameters().
```python
torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2.0)
torch.nn.utils.clip_grad_value_(parameters, clip_value)
```
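clip_grad_norm_ also returns the total gradient norm measured before clipping, which is handy for logging how large the raw gradients were; clip_grad_value_ returns nothing. A minimal sketch, assuming a model whose gradients have already been computed with loss.backward():

```python
# clip_grad_norm_ returns the total norm of all gradients, computed before clipping
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"gradient norm before clipping: {total_norm.item():.4f}")
```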
Example
This example shows how to clip gradients by norm while training a simple neural network on random data. It clips the gradients to a maximum norm of 1.0 before the optimizer step.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Dummy input and target
inputs = torch.randn(5, 10)
targets = torch.randn(5, 1)
criterion = nn.MSELoss()

# Forward pass
outputs = model(inputs)
loss = criterion(outputs, targets)

# Backward pass
loss.backward()

# Clip gradients by norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Optimizer step
optimizer.step()

# Print clipped gradients
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name} grad norm: {param.grad.norm().item():.4f}")
```
Output
The weight and bias gradient norms are printed; the exact values vary with the random inputs. After clipping, the combined norm of all gradients is at most 1.0.
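In a full training loop, the same call is made on every iteration, between loss.backward() and optimizer.step(). A minimal sketch, using a toy dataset and an arbitrary epoch count as placeholders:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Toy setup; the dataset, batch size, and epoch count are placeholders
model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
loader = DataLoader(TensorDataset(torch.randn(100, 10), torch.randn(100, 1)), batch_size=20)

for epoch in range(10):
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        # Clip on every iteration, after backward() and before step()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
```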
Common Pitfalls
- Clipping after optimizer.step() has no effect, because the weights have already been updated.
- Clipping after optimizer.zero_grad() only clips gradients that are already zero.
- A max_norm that is too small can slow training by shrinking gradients excessively.
- For RNNs, clipping by norm is usually better than clipping by value (see the RNN sketch after the example below).
```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

inputs = torch.randn(5, 10)
targets = torch.randn(5, 1)

# WRONG: zeroing the gradients before clipping leaves nothing to clip
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.zero_grad()                                              # gradients are now zero
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)   # clips zeros, no effect
optimizer.step()                                                   # steps with zero gradients

# RIGHT: clip after backward() and before step()
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```
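For the RNN pitfall above, a minimal norm-clipping sketch on a toy LSTM (the layer sizes, data, and learning rate are placeholders, not a recommended configuration):

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Tiny LSTM regression head on dummy sequences
rnn = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
head = nn.Linear(16, 1)
params = list(rnn.parameters()) + list(head.parameters())
optimizer = optim.Adam(params, lr=1e-3)
criterion = nn.MSELoss()

seqs = torch.randn(4, 20, 8)      # (batch, time, features)
targets = torch.randn(4, 1)

optimizer.zero_grad()
outputs, _ = rnn(seqs)            # outputs: (batch, time, hidden)
preds = head(outputs[:, -1, :])   # use the last time step
loss = criterion(preds, targets)
loss.backward()
# Norm clipping rescales all gradients together and preserves their direction,
# which is why it is usually preferred over value clipping for recurrent networks
torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
optimizer.step()
```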
Quick Reference
Gradient Clipping Cheat Sheet:
| Function | Purpose | Typical Use |
|---|---|---|
| clip_grad_norm_ | Clip gradients by norm | clip_grad_norm_(model.parameters(), max_norm=1.0) |
| clip_grad_value_ | Clip gradients by value | clip_grad_value_(model.parameters(), clip_value=0.5) |
| Call order | After loss.backward(), before optimizer.step() | loss.backward(); clip_grad_norm_(...); optimizer.step() |
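clip_grad_value_ clamps each gradient element independently to the range [-clip_value, clip_value]. A minimal sketch reusing the toy linear model from the example above:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

inputs = torch.randn(5, 10)
targets = torch.randn(5, 1)

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
# Clamp every gradient element to [-0.5, 0.5]
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
optimizer.step()
```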
Key Takeaways
- Always clip gradients after calling loss.backward() and before optimizer.step().
- Use clip_grad_norm_ to limit the overall gradient size by norm, which is common for RNNs.
- Avoid clipping gradients after zero_grad(), as the gradients will already be zeroed out.
- Choose max_norm or clip_value carefully to avoid hurting training progress.
- Gradient clipping helps prevent exploding gradients and stabilizes training.