How to Clip Gradients in PyTorch: Syntax and Example
In PyTorch, you can clip gradients using
torch.nn.utils.clip_grad_norm_ or torch.nn.utils.clip_grad_value_ to limit their size during training. This helps prevent exploding gradients by scaling or capping gradients before the optimizer updates the model weights.
Syntax
Use torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2.0) to clip gradients by their norm. parameters is an iterable of model parameters, max_norm is the maximum allowed total norm, and norm_type selects which norm to use (the default is the L2 norm).
Alternatively, use torch.nn.utils.clip_grad_value_(parameters, clip_value) to clip gradients by value, capping each gradient element to the range [-clip_value, clip_value].
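To see the capping behavior concretely, here is a minimal sketch with hand-set gradient values (the layer shape and gradient values are illustrative, not from the original example):

```python
import torch
import torch.nn as nn

# Illustrative layer with manually assigned gradients, to show how
# clip_grad_value_ caps each gradient element independently.
layer = nn.Linear(2, 2)
layer.weight.grad = torch.tensor([[5.0, -3.0], [0.2, -0.1]])
layer.bias.grad = torch.tensor([2.0, -2.0])

torch.nn.utils.clip_grad_value_(layer.parameters(), clip_value=0.5)

# Elements larger than 0.5 in magnitude are clamped to +/-0.5;
# smaller elements (like 0.2 and -0.1) are left unchanged.
print(layer.weight.grad)
print(layer.bias.grad)
```

Note that unlike norm clipping, value clipping changes the direction of the gradient whenever any element is clamped, since elements are capped independently.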
python
torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2)
torch.nn.utils.clip_grad_value_(parameters, clip_value)
Example
This example shows how to clip gradients by norm during training a simple linear model. It clips gradients to a max norm of 1.0 before the optimizer step.
python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple linear model
model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Dummy input and target
inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
targets = torch.tensor([[1.0], [2.0]])
criterion = nn.MSELoss()

# Forward pass
outputs = model(inputs)
loss = criterion(outputs, targets)

# Backward pass
loss.backward()

# Clip gradients by norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Check gradient norms
total_norm = 0
for p in model.parameters():
    param_norm = p.grad.norm(2)
    total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
print(f"Total gradient norm after clipping: {total_norm:.4f}")

# Optimizer step
optimizer.step()
Output
Total gradient norm after clipping: 1.0000
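clip_grad_norm_ also returns the total norm measured before clipping, which is handy for logging gradient health without a manual loop like the one above. A minimal sketch with hand-set gradients (the values are illustrative):

```python
import torch
import torch.nn as nn

# Hand-set gradients so the pre-clip norm is exactly sqrt(3^2 + 4^2) = 5.
layer = nn.Linear(2, 1)
layer.weight.grad = torch.tensor([[3.0, 4.0]])
layer.bias.grad = torch.tensor([0.0])

# The return value is the total norm *before* clipping.
pre_clip_norm = torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=1.0)
print(f"Norm before clipping: {pre_clip_norm:.4f}")

# After the call, gradients are scaled so their total norm is at most max_norm.
print(layer.weight.grad)
```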
Common Pitfalls
- Calling a clipping utility before loss.backward() does nothing, because gradients do not exist yet (recent PyTorch versions simply skip parameters whose gradients are None).
- Clipping gradients after optimizer.step() has no effect on that update; always clip before the optimizer updates the weights.
- Using too small a max_norm can slow training by overly shrinking gradients.
- For models with multiple parameter groups, ensure clipping is applied to all parameters.
python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
inputs = torch.tensor([[1.0, 2.0]])
targets = torch.tensor([[1.0]])
criterion = nn.MSELoss()

outputs = model(inputs)
loss = criterion(outputs, targets)

# WRONG: clipping before backward -- no gradients exist yet, so this does nothing
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

loss.backward()

# WRONG: clipping after the optimizer step -- the update already used unclipped gradients
# optimizer.step()
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# RIGHT: clip after backward, before the optimizer step
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
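The multi-group pitfall from the list above can be sketched as follows: the optimizer splits parameters into two groups with different learning rates, but passing model.parameters() to clip_grad_norm_ still covers every parameter in one call. The model architecture and learning rates here are illustrative:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Two parameter groups with different learning rates (illustrative setup).
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = optim.SGD([
    {"params": model[0].parameters(), "lr": 0.1},
    {"params": model[2].parameters(), "lr": 0.01},
])

loss = model(torch.randn(3, 4)).pow(2).mean()
loss.backward()

# model.parameters() covers all parameters regardless of how the
# optimizer groups them, so one call clips everything jointly.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```

Clipping each group separately would instead bound each group's norm individually, which changes the relative scaling between groups; a single joint call preserves the overall gradient direction.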
Quick Reference
Gradient Clipping Methods in PyTorch:
- clip_grad_norm_: Clips gradients by their norm (recommended for most cases).
- clip_grad_value_: Clips gradients by value, capping each element.
Usage Tips:
- Call after loss.backward() and before optimizer.step().
- Choose max_norm based on your model and training stability.
- Use clipping to prevent exploding gradients in deep or recurrent networks.
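Putting the tips together, a minimal training-loop sketch showing where clipping fits in the zero_grad → forward → backward → clip → step cycle (the model, data, and hyperparameters are illustrative):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
targets = torch.tensor([[1.0], [2.0]])

for epoch in range(5):
    optimizer.zero_grad()                  # clear old gradients
    loss = criterion(model(inputs), targets)
    loss.backward()                        # compute gradients
    # clip after backward, before step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()                       # update with clipped gradients
    print(f"epoch {epoch}: loss={loss.item():.4f}")
```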
Key Takeaways
Always clip gradients after backward() and before optimizer.step() to control gradient size.
Use torch.nn.utils.clip_grad_norm_ to clip gradients by their norm for stable training.
Avoid clipping gradients before backward() or after optimizer.step(); in both cases the clipping has no effect on the weight update.
Choose an appropriate max_norm value to prevent exploding gradients without slowing training.
Gradient clipping is especially useful in deep or recurrent neural networks to maintain training stability.
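For the recurrent case, a minimal sketch of clipping in an RNN, where gradients can grow with sequence length (the architecture, sequence length, and random data are illustrative):

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Small RNN plus a linear head; long sequences are where exploding
# gradients typically appear.
rnn = nn.RNN(input_size=4, hidden_size=8, batch_first=True)
head = nn.Linear(8, 1)
params = list(rnn.parameters()) + list(head.parameters())
optimizer = optim.SGD(params, lr=0.01)

x = torch.randn(2, 50, 4)      # batch of 2 sequences, 50 time steps
target = torch.randn(2, 1)

out, _ = rnn(x)                          # out: (batch, seq, hidden)
loss = nn.functional.mse_loss(head(out[:, -1]), target)
loss.backward()

# Clip all parameters jointly before the update.
torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
optimizer.step()
```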