How to Use Learning Rate Scheduler in PyTorch for Better Training
In PyTorch, use a scheduler from torch.optim.lr_scheduler to adjust the learning rate during training. Create a scheduler object linked to your optimizer, then call scheduler.step() at each epoch or batch to update the learning rate automatically.
Syntax
To use a learning rate scheduler in PyTorch, first create an optimizer, then create a scheduler by passing the optimizer and scheduler-specific parameters. During training, call scheduler.step() to update the learning rate.
- optimizer: Your optimizer instance (e.g., Adam, SGD).
- scheduler: A learning rate scheduler from torch.optim.lr_scheduler.
- step(): Method to update the learning rate, called each epoch or batch depending on the scheduler type.
```python
import torch
import torch.optim as optim

model = torch.nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Create a scheduler that decreases LR by 10% every epoch
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

for epoch in range(3):
    # Training code here
    optimizer.step()
    scheduler.step()  # Update learning rate
    print(f"Epoch {epoch+1}, LR: {scheduler.get_last_lr()[0]:.5f}")
```
Output
Epoch 1, LR: 0.09000
Epoch 2, LR: 0.08100
Epoch 3, LR: 0.07290
Example
This example shows how to use the StepLR scheduler with a simple linear model and an SGD optimizer. The learning rate decreases by 10% after each epoch, and the updated learning rate is printed.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
model = nn.Linear(5, 1)

# Optimizer with initial learning rate 0.1
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Scheduler to reduce LR by 10% every epoch
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

for epoch in range(5):
    # Dummy training step
    optimizer.zero_grad()
    inputs = torch.randn(10, 5)
    outputs = model(inputs)
    loss = outputs.mean()
    loss.backward()
    optimizer.step()

    # Update learning rate
    scheduler.step()

    # Print current learning rate
    print(f"Epoch {epoch+1}: Learning Rate = {scheduler.get_last_lr()[0]:.5f}")
```
Output
Epoch 1: Learning Rate = 0.09000
Epoch 2: Learning Rate = 0.08100
Epoch 3: Learning Rate = 0.07290
Epoch 4: Learning Rate = 0.06561
Epoch 5: Learning Rate = 0.05905
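The same pattern works for the other epoch-stepped schedulers listed in the Quick Reference below; only the construction and decay shape change, not the step() call. As a minimal sketch (the T_max and eta_min values here are arbitrary choices for illustration), CosineAnnealingLR can be swapped in like this:
```python
import torch
import torch.optim as optim

model = torch.nn.Linear(5, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Cosine schedule over 5 epochs: LR follows a cosine curve from 0.1 down toward eta_min
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=0.001)

for epoch in range(5):
    # ... training step would go here ...
    optimizer.step()
    scheduler.step()  # Same call as with StepLR; only the decay curve differs
    print(f"Epoch {epoch+1}: Learning Rate = {scheduler.get_last_lr()[0]:.5f}")
```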
Common Pitfalls
- Not calling scheduler.step() at the right time: for most schedulers, call scheduler.step() after optimizer.step() each epoch.
- Using scheduler.step() every batch when it expects epoch steps: some schedulers expect a step per epoch, others per batch; check the documentation (see the per-batch sketch after the snippet below).
- Forgetting to pass the optimizer to the scheduler: the scheduler needs the optimizer to adjust its learning rate.
- Not checking learning rate updates: use scheduler.get_last_lr() to verify changes.
```python
import torch
import torch.optim as optim

model = torch.nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

# Wrong: calling scheduler.step() before optimizer.step()
for epoch in range(3):
    scheduler.step()  # Wrong timing
    optimizer.step()

# Correct order:
for epoch in range(3):
    optimizer.step()
    scheduler.step()
```
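For schedulers that are designed to be stepped per batch, the scheduler.step() call moves inside the batch loop. A minimal sketch with OneCycleLR (the max_lr, epochs, and steps_per_epoch values are illustrative assumptions, and random data stands in for a real dataset):
```python
import torch
import torch.optim as optim

model = torch.nn.Linear(5, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

epochs, steps_per_epoch = 3, 10
# OneCycleLR is stepped once per batch, not once per epoch
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.1, epochs=epochs, steps_per_epoch=steps_per_epoch
)

for epoch in range(epochs):
    for batch in range(steps_per_epoch):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 5)).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()  # Per-batch update
    print(f"Epoch {epoch+1}: Learning Rate = {scheduler.get_last_lr()[0]:.5f}")
```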
Quick Reference
Here are common PyTorch learning rate schedulers:
| Scheduler | Description | When to call step() |
|---|---|---|
| StepLR | Reduces LR by gamma every fixed number of epochs | After each epoch |
| MultiStepLR | Reduces LR at specified epochs | After each epoch |
| ExponentialLR | Reduces LR exponentially every epoch | After each epoch |
| CosineAnnealingLR | LR follows cosine curve | After each epoch |
| ReduceLROnPlateau | Reduces LR when metric stops improving | After validation step (pass metric) |
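ReduceLROnPlateau differs from the others in the table: its step() takes the monitored metric, typically a validation loss. A minimal sketch, with a dummy constant standing in for real validation and illustrative factor and patience values:
```python
import torch
import torch.optim as optim

model = torch.nn.Linear(5, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Halve the LR when the monitored metric has not improved for 2 epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2
)

for epoch in range(10):
    # ... training and validation would go here ...
    val_loss = 1.0  # Dummy metric that never improves, so the LR drops after the patience window
    scheduler.step(val_loss)  # Pass the metric, unlike the other schedulers
    print(f"Epoch {epoch+1}: LR = {optimizer.param_groups[0]['lr']:.5f}")
```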
Key Takeaways
- Create a scheduler by passing your optimizer and parameters from torch.optim.lr_scheduler.
- Call scheduler.step() at the correct time, usually after optimizer.step() each epoch.
- Use scheduler.get_last_lr() to check the current learning rate during training.
- Choose the scheduler type based on your training needs and learning rate decay strategy.
- Be careful with schedulers like ReduceLROnPlateau that require a metric argument when stepping.