PyTorch · How-To · Beginner · 4 min read

How to Use a Learning Rate Scheduler in PyTorch for Better Training

In PyTorch, use a scheduler from torch.optim.lr_scheduler to adjust the learning rate during training. Create a scheduler object linked to your optimizer, then call scheduler.step() each epoch (or each batch, depending on the scheduler) to update the learning rate automatically.

📐 Syntax

To use a learning rate scheduler in PyTorch, first create an optimizer, then create a scheduler by passing the optimizer and scheduler-specific parameters. During training, call scheduler.step() to update the learning rate.

  • optimizer: Your optimizer instance (e.g., Adam, SGD).
  • scheduler: A learning rate scheduler from torch.optim.lr_scheduler.
  • step(): Method to update the learning rate, called each epoch or batch depending on scheduler type.
python
import torch
import torch.optim as optim

model = torch.nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Create a scheduler that decreases LR by 10% every epoch
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

for epoch in range(3):
    # Training code here
    optimizer.step()
    scheduler.step()  # Update learning rate
    print(f"Epoch {epoch+1}, LR: {scheduler.get_last_lr()[0]:.5f}")
Output
Epoch 1, LR: 0.09000
Epoch 2, LR: 0.08100
Epoch 3, LR: 0.07290

💻 Example

This example shows how to use StepLR scheduler with a simple linear model and SGD optimizer. The learning rate decreases by 10% after each epoch, and the updated learning rate is printed.

python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
model = nn.Linear(5, 1)

# Optimizer with initial learning rate 0.1
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Scheduler to reduce LR by 10% every epoch
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

for epoch in range(5):
    # Dummy training step
    optimizer.zero_grad()
    inputs = torch.randn(10, 5)
    outputs = model(inputs)
    loss = outputs.mean()
    loss.backward()
    optimizer.step()

    # Update learning rate
    scheduler.step()

    # Print current learning rate
    print(f"Epoch {epoch+1}: Learning Rate = {scheduler.get_last_lr()[0]:.5f}")
Output
Epoch 1: Learning Rate = 0.09000
Epoch 2: Learning Rate = 0.08100
Epoch 3: Learning Rate = 0.07290
Epoch 4: Learning Rate = 0.06561
Epoch 5: Learning Rate = 0.05905

⚠️ Common Pitfalls

  • Not calling scheduler.step() at the right time: For most schedulers, call scheduler.step() after optimizer.step() each epoch.
  • Using scheduler.step() every batch when it expects epoch steps: Some schedulers step once per epoch, others once per batch. Check the documentation; a per-batch example with OneCycleLR appears after the snippet below.
  • Forgetting to pass optimizer to scheduler: Scheduler needs the optimizer to adjust its learning rate.
  • Not checking learning rate updates: Use scheduler.get_last_lr() to verify changes.
python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(5, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

# Wrong: scheduler.step() called before optimizer.step()
for epoch in range(3):
    scheduler.step()  # wrong timing: PyTorch warns and the first LR value is skipped
    optimizer.step()

# Correct order: optimizer.step() first, then scheduler.step()
for epoch in range(3):
    optimizer.step()
    scheduler.step()
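
Pitfall two above also bites in the other direction: schedulers such as OneCycleLR are designed to be stepped once per batch, not once per epoch. Below is a minimal sketch of per-batch stepping; the epoch count, batch count, and max_lr value are illustrative placeholders, not part of the example above.
python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(5, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

epochs, steps_per_epoch = 3, 4  # illustrative sizes
# OneCycleLR expects one scheduler.step() per batch
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.1, epochs=epochs, steps_per_epoch=steps_per_epoch
)

for epoch in range(epochs):
    for _ in range(steps_per_epoch):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 5)).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()  # per-batch update for OneCycleLR
    print(f"Epoch {epoch+1}: LR = {scheduler.get_last_lr()[0]:.5f}")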

📊 Quick Reference

Here are common PyTorch learning rate schedulers:

Scheduler           | Description                                       | When to call step()
StepLR              | Reduces LR by gamma every fixed number of epochs  | After each epoch
MultiStepLR         | Reduces LR at specified epoch milestones          | After each epoch
ExponentialLR       | Reduces LR exponentially every epoch              | After each epoch
CosineAnnealingLR   | LR follows a cosine curve                         | After each epoch
ReduceLROnPlateau   | Reduces LR when a metric stops improving          | After validation (pass the metric)
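
ReduceLROnPlateau is the odd one out: scheduler.step() must be given the metric you monitor, and the learning rate is easiest to read back from the optimizer itself. Below is a minimal sketch; the constant val_loss is only a placeholder for a real validation loss.
python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(5, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Cut the LR by a factor of 10 if the monitored metric stalls for 2 epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=2
)

for epoch in range(5):
    # ... training and validation would go here ...
    val_loss = 1.0  # placeholder; pass your real validation loss
    scheduler.step(val_loss)  # unlike other schedulers, step() takes the metric
    # read the LR from the optimizer (older versions lack get_last_lr() here)
    print(f"Epoch {epoch+1}: LR = {optimizer.param_groups[0]['lr']:.5f}")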

Key Takeaways

  • Create a scheduler from torch.optim.lr_scheduler by passing it your optimizer and scheduler-specific parameters.
  • Call scheduler.step() at the correct time, usually after optimizer.step() each epoch.
  • Use scheduler.get_last_lr() to check the current learning rate during training.
  • Choose the scheduler type based on your training needs and learning rate decay strategy.
  • Be careful with schedulers like ReduceLROnPlateau that require a metric when stepping.