Warmup strategies ease the model into training by gradually increasing the learning rate from near zero to its target value. This avoids large early updates that can destabilize training before the weights have settled.
Warmup strategies in PyTorch
```python
import torch
from torch.optim.lr_scheduler import LambdaLR

base_lr = 1e-3       # target learning rate after warmup (illustrative value)
warmup_steps = 1000  # length of the warmup phase (illustrative value)

# Define a warmup function: returns a multiplier applied to base_lr
def warmup_lambda(current_step):
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    return 1.0

# Create optimizer (model is assumed to be an existing nn.Module)
optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)

# Create scheduler with warmup
scheduler = LambdaLR(optimizer, lr_lambda=warmup_lambda)
```
The warmup function returns a multiplier for the learning rate.
During warmup steps, the multiplier grows from 0 to 1, then stays at 1.
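To make that ramp concrete, here is a quick check of the multiplier values (the 4-step warmup is chosen purely for illustration):

```python
warmup_steps = 4  # tiny value, just for illustration

def warmup_lambda(step):
    if step < warmup_steps:
        return step / warmup_steps
    return 1.0

print([warmup_lambda(s) for s in range(6)])
# → [0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```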
```python
# Linear warmup over the first 1000 steps, then constant
def warmup_lambda(step):
    return min(1.0, step / 1000)
```
```python
# Linear warmup over 500 steps, then constant
def warmup_lambda(step):
    if step < 500:
        return step / 500
    return 1.0
```
```python
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, step / 2000))
```
This code trains a simple linear model with a learning-rate warmup over 5 steps. The learning rate starts at 0, ramps up to 0.1, then stays constant.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

# Simple model
model = nn.Linear(10, 1)

# Parameters
base_lr = 0.1
warmup_steps = 5

# Optimizer
optimizer = optim.SGD(model.parameters(), lr=base_lr)

# Warmup function
def warmup_lambda(step):
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    return 1.0

# Scheduler
scheduler = LambdaLR(optimizer, lr_lambda=warmup_lambda)

# Dummy data
inputs = torch.randn(10, 10)
targets = torch.randn(10, 1)

# Loss
criterion = nn.MSELoss()

print('Step | Learning Rate | Loss')
for step in range(10):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    scheduler.step()
    lr = optimizer.param_groups[0]['lr']
    print(f'{step:4d} | {lr:.4f} | {loss.item():.4f}')
```
Warmup helps prevent the model from making large, unstable updates early in training.
You can combine warmup with other learning rate schedules for better results.
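One common combination, sketched here with PyTorch's built-in `SequentialLR`, is linear warmup followed by cosine decay. The step counts (1000 warmup steps out of 10000 total) and the base learning rate are illustrative choices, not values from this article:

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

warmup_steps = 1000   # illustrative value
total_steps = 10000   # illustrative value

# Phase 1: linear warmup from 0 up to the base learning rate
warmup = LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, step / warmup_steps))

# Phase 2: cosine decay over the remaining steps
cosine = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps)

# Hand off from warmup to cosine decay at warmup_steps
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine],
                         milestones=[warmup_steps])
```

Each call to `scheduler.step()` advances whichever phase is active, so the training loop does not need to know where the warmup ends.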
Adjust warmup_steps based on your dataset size and model complexity.
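One way to do that, shown with a hypothetical helper (`compute_warmup_steps` and the 5% default are illustrative, not from this article), is to derive `warmup_steps` as a fraction of the planned number of training steps:

```python
# Hypothetical helper: pick warmup_steps as a fraction of total training steps.
def compute_warmup_steps(dataset_size, batch_size, epochs, warmup_fraction=0.05):
    steps_per_epoch = (dataset_size + batch_size - 1) // batch_size  # ceil division
    total_steps = steps_per_epoch * epochs
    # Warm up for a small fraction of training, but never less than one step
    return max(1, int(total_steps * warmup_fraction))

warmup_steps = compute_warmup_steps(dataset_size=50_000, batch_size=32, epochs=3)
```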
Warmup strategies gradually increase the learning rate at the start of training.
This helps models learn smoothly and avoid unstable updates.
In PyTorch, LambdaLR with a custom function is a simple way to add warmup.