How to Use Weight Decay in PyTorch for Regularization

In PyTorch, you apply weight decay by passing the weight_decay argument to an optimizer such as torch.optim.SGD or torch.optim.Adam. In these optimizers it acts as L2 regularization: weight_decay × parameter is added to each parameter's gradient before the update, which penalizes large weights. Set weight_decay to a small positive value (e.g., 0.01) when creating the optimizer.
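For plain SGD without momentum, the update that weight_decay produces can be written as follows. Here η is the learning rate (lr), λ is weight_decay, θ are the parameters, and L is the training loss. The second line shows why this counts as L2 regularization: λθ is exactly the gradient of a (λ/2)‖θ‖² penalty. Momentum and adaptive optimizers such as Adam transform this gradient further, so the simple shrink factor (1 − ηλ) holds only in this basic case.

```latex
\begin{aligned}
\theta_{t+1} &= \theta_t - \eta\left(\nabla L(\theta_t) + \lambda\,\theta_t\right)
             = (1 - \eta\lambda)\,\theta_t - \eta\,\nabla L(\theta_t) \\
\nabla_\theta\!\left[L(\theta) + \tfrac{\lambda}{2}\lVert\theta\rVert_2^2\right] &= \nabla L(\theta) + \lambda\,\theta
\end{aligned}
```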

Syntax

The weight_decay parameter is passed when you create an optimizer in PyTorch. It controls the strength of L2 regularization, which helps reduce overfitting by penalizing large weights.

  • optimizer_class: The optimizer type, e.g., torch.optim.SGD or torch.optim.Adam.
  • params: The model parameters to optimize.
  • lr: Learning rate for the optimizer.
  • weight_decay: The L2 penalty coefficient (float, usually small such as 0.01; it defaults to 0, meaning no decay, in SGD and Adam).
```python
# General form: optimizer_class(params, lr=..., weight_decay=...)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=0.01)
```
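One way to see what weight_decay does in plain SGD (no momentum) is to compare a single optimizer step against plain SGD with an explicit (weight_decay / 2) × ‖θ‖² penalty added to the loss. This is a minimal sketch that assumes a toy nn.Linear model and random data; the two updates should agree up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lr, wd = 0.1, 0.01

# Two identical copies of a small model (toy setup for illustration)
model_a = nn.Linear(2, 1)
model_b = nn.Linear(2, 1)
model_b.load_state_dict(model_a.state_dict())

x = torch.randn(4, 2)
y = torch.randn(4, 1)
criterion = nn.MSELoss()

# (a) Weight decay handled by the optimizer
opt_a = torch.optim.SGD(model_a.parameters(), lr=lr, weight_decay=wd)
opt_a.zero_grad()
criterion(model_a(x), y).backward()
opt_a.step()

# (b) No optimizer weight decay; L2 penalty (wd / 2) * ||theta||^2 added to the loss
opt_b = torch.optim.SGD(model_b.parameters(), lr=lr)
opt_b.zero_grad()
penalty = sum((p ** 2).sum() for p in model_b.parameters())
(criterion(model_b(x), y) + 0.5 * wd * penalty).backward()
opt_b.step()

# Compare the updated parameters of both models
for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
    print(torch.allclose(p_a, p_b, atol=1e-6))
```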

Example

This example shows how to create a simple linear model and apply weight decay using the SGD optimizer. It trains the model on dummy data and prints the loss to demonstrate training with weight decay.

```python
import torch
import torch.nn as nn

# Simple linear model
model = nn.Linear(2, 1)

# Dummy data
inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
targets = torch.tensor([[1.0], [2.0], [3.0]])

# Optimizer with weight decay. The learning rate is kept small because these
# unnormalized inputs make larger steps (e.g. lr=0.1) overshoot and diverge.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.01)

# Loss function
criterion = nn.MSELoss()

# Training loop
for epoch in range(5):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
```
Output
The printed values depend on the random initialization of nn.Linear, so they differ from run to run; the loss should generally decrease from one epoch to the next.
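With plain SGD, weight decay on its own multiplies every decayed weight by (1 − lr × weight_decay) at each step. The sketch below isolates that effect by giving a single parameter tensor a zero loss gradient, an artificial setup chosen only to make the shrinkage visible; the values and names are arbitrary.

```python
import torch

lr, wd = 0.1, 0.01

# One parameter tensor standing in for a layer's weights (values are arbitrary)
w = torch.nn.Parameter(torch.tensor([2.0, -4.0]))
optimizer = torch.optim.SGD([w], lr=lr, weight_decay=wd)

for step in range(3):
    # Zero loss gradient, so only weight decay changes w.
    # (If w.grad were None, SGD would skip this parameter entirely.)
    w.grad = torch.zeros_like(w)
    optimizer.step()

# After 3 steps, w should equal its start value times (1 - lr * wd) ** 3
expected = torch.tensor([2.0, -4.0]) * (1 - lr * wd) ** 3
print(w.detach(), expected)
```

Because the shrink factor depends on lr × weight_decay, changing the learning rate also changes how strongly weight decay acts per step.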

Common Pitfalls

Common mistakes when using weight decay include:

  • Setting weight_decay too high, which can cause underfitting by overly penalizing weights.
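  • Confusing weight_decay with learning rate. They are separate hyperparameters, although with SGD the per-step shrinkage is lr × weight_decay, so the two interact.
  • Applying weight decay to bias terms or batch norm parameters unintentionally (passing model.parameters() to the optimizer includes them), which is commonly avoided in practice.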

To avoid the last issue, put the parameters that should not be decayed into their own parameter group with weight_decay set to 0.

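```python
import torch
import torch.nn as nn

# Example model with biases and a batch norm layer
model = nn.Sequential(nn.Linear(2, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))

# Wrong: applying weight decay to all parameters, including biases and batch norm
optimizer_wrong = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=0.01)

# Right: exclude biases and batch norm parameters from weight decay.
# Matching names such as 'bn' is fragile (in nn.Sequential the batch norm layer
# is named '1'), so this uses a common heuristic instead: biases and
# normalization scales/shifts are 1-D tensors, weight matrices are 2-D or more.
params_with_decay = []
params_without_decay = []
for param in model.parameters():
    if not param.requires_grad:
        continue
    if param.ndim <= 1:
        params_without_decay.append(param)
    else:
        params_with_decay.append(param)

optimizer_right = torch.optim.SGD([
    {'params': params_with_decay, 'weight_decay': 0.01},
    {'params': params_without_decay, 'weight_decay': 0.0}
], lr=0.1)
```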

Quick Reference

Summary tips for using weight decay in PyTorch:

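  • Use weight_decay in the optimizer to add L2 regularization. With torch.optim.Adam this L2 term is added to the gradient; torch.optim.AdamW applies decoupled weight decay instead.
  • Commonly used values range from 0.0001 to 0.01, depending on the model.
  • Excluding biases and batch norm parameters from weight decay often gives better results.
  • Weight decay is a separate hyperparameter from the learning rate, but its per-step effect scales with lr, so revisit it when you change the learning rate.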
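torch.optim.Adam implements weight_decay as an L2 term added to the gradient, so the penalty passes through Adam's adaptive normalization, while torch.optim.AdamW decays the weights directly (decoupled weight decay). The sketch below makes the difference visible by giving a parameter a zero loss gradient (an artificial setup; the helper name one_step and the parameter values are illustrative) and taking a single step with each optimizer.

```python
import torch

lr, wd = 0.1, 0.01

def one_step(optimizer_cls):
    # Fresh parameter for each optimizer; values chosen arbitrarily
    w = torch.nn.Parameter(torch.tensor([2.0, -4.0]))
    opt = optimizer_cls([w], lr=lr, weight_decay=wd)
    # Zero loss gradient, so any change comes from weight decay alone
    w.grad = torch.zeros_like(w)
    opt.step()
    return w.detach().clone()

w_adam = one_step(torch.optim.Adam)    # L2 term is rescaled by Adam's normalization
w_adamw = one_step(torch.optim.AdamW)  # decoupled: w is scaled by (1 - lr * wd)

print("Adam :", w_adam)   # each weight moves by roughly lr, regardless of its size
print("AdamW:", w_adamw)  # each weight shrinks by 0.1%
```

With Adam, the normalized L2 gradient moves both weights by about the same absolute amount, whereas AdamW shrinks each weight in proportion to its size, which is the behavior usually meant by "weight decay".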

Key Takeaways

Set weight_decay in the optimizer to apply L2 regularization in PyTorch.
Start with small values such as 0.01; too much decay causes underfitting.
Excluding biases and batch norm parameters from weight decay is a common practice.
Weight decay is not the same as learning rate; tune it as its own hyperparameter, keeping in mind that its per-step effect scales with the learning rate.
Well-tuned weight decay can improve model generalization and reduce overfitting.