
Warmup strategies in PyTorch

Introduction

Warmup strategies let the model start learning gently by gradually increasing the learning rate over the first steps of training. This avoids large, unstable updates that can derail optimization early on. Common situations where warmup helps:

When training a deep neural network from scratch to avoid unstable updates.
When using a large learning rate that might be too strong at the start.
When fine-tuning a pretrained model to adapt smoothly to new data.
When training on a new dataset that is very different from previous data.
When you notice training loss jumping or not improving at the beginning.
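The core idea can be written as a single formula: the effective learning rate at step t is base_lr * min(1, t / warmup_steps). A minimal sketch, with illustrative values:

```python
# Linear warmup: scale the base learning rate by min(1, t / warmup_steps).
def warmup_lr(base_lr, step, warmup_steps):
    return base_lr * min(1.0, step / warmup_steps)

# The learning rate ramps up linearly, then holds at base_lr.
for t in range(8):
    print(t, warmup_lr(0.1, t, 5))
```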
Syntax
PyTorch
import torch
from torch.optim.lr_scheduler import LambdaLR

base_lr = 1e-3       # base learning rate
warmup_steps = 1000  # number of warmup steps

# Define a warmup function
def warmup_lambda(current_step):
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    return 1.0

# Create optimizer (assumes `model` is already defined)
optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)

# Create scheduler with warmup
scheduler = LambdaLR(optimizer, lr_lambda=warmup_lambda)

The warmup function returns a multiplier for the learning rate.

During warmup steps, the multiplier grows from 0 to 1, then stays at 1.
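To see the multiplier schedule concretely, you can evaluate the lambda by hand (warmup_steps is set to 4 here purely for illustration):

```python
warmup_steps = 4  # small illustrative value

def warmup_lambda(current_step):
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    return 1.0

# Multipliers for the first steps: 0.0, 0.25, 0.5, 0.75, then 1.0 onward.
print([warmup_lambda(s) for s in range(6)])
```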

Examples
This linearly increases the learning rate from 0 to full over 1000 steps.
PyTorch
def warmup_lambda(step):
    return min(1.0, step / 1000)
Warmup for 500 steps, then keep learning rate constant.
PyTorch
def warmup_lambda(step):
    if step < 500:
        return step / 500
    else:
        return 1.0
Using a lambda function directly to warm up over 2000 steps.
PyTorch
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, step / 2000))
Sample Model

This code trains a simple linear model with a warmup strategy for the learning rate over 5 steps. The learning rate starts at 0 and grows to 0.1 gradually, then stays constant.

PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

# Simple model
model = nn.Linear(10, 1)

# Parameters
base_lr = 0.1
warmup_steps = 5

# Optimizer
optimizer = optim.SGD(model.parameters(), lr=base_lr)

# Warmup function
def warmup_lambda(step):
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    return 1.0

# Scheduler
scheduler = LambdaLR(optimizer, lr_lambda=warmup_lambda)

# Dummy data
inputs = torch.randn(10, 10)
targets = torch.randn(10, 1)

# Loss
criterion = nn.MSELoss()

print('Step | Learning Rate | Loss')
for step in range(10):
    lr = optimizer.param_groups[0]['lr']  # LR actually used for this step
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    scheduler.step()
    print(f'{step:4d} | {lr:.4f}        | {loss.item():.4f}')
Important Notes

Warmup helps prevent the model from making large, unstable updates early in training.

You can combine warmup with other learning rate schedules for better results.
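One way to do this is to chain a linear warmup with cosine decay using SequentialLR; the step counts below are illustrative, not recommendations:

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

warmup_steps = 100   # illustrative
total_steps = 1000   # illustrative

# Linear warmup from 1% of the base LR up to the base LR...
warmup = LinearLR(optimizer, start_factor=0.01, total_iters=warmup_steps)
# ...then cosine decay over the remaining steps.
decay = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay],
                         milestones=[warmup_steps])

# In the training loop, call scheduler.step() once per step,
# after optimizer.step().
```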

Adjust warmup_steps based on your dataset size and model complexity.
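A common rule of thumb (an assumption, not a PyTorch default) is to warm up for a small fraction of the total number of optimizer steps:

```python
# Hypothetical heuristic: warm up for ~5% of total optimizer steps.
total_steps = 20_000      # e.g. epochs * batches_per_epoch
warmup_fraction = 0.05    # tune per model and dataset
warmup_steps = max(1, int(total_steps * warmup_fraction))
print(warmup_steps)
```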

Summary

Warmup strategies gradually increase the learning rate at the start of training.

This helps models learn smoothly and avoid unstable updates.

In PyTorch, LambdaLR with a custom function is a simple way to add warmup.