
Weight decay (L2 regularization) in PyTorch

Introduction
Weight decay (also called L2 regularization) helps prevent a model from memorizing its training data by penalizing large weights, keeping the learned function simple. It is useful:
When your model performs well on training data but poorly on new data (overfitting).
When training deep neural networks, which tend to have many parameters.
When you want to improve your model's ability to generalize to unseen examples.
When you want to penalize large weights during training to keep the model simpler.
Syntax
PyTorch
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=decay_rate)
weight_decay sets the L2 regularization strength; typical values are small, such as 0.0001 or 0.01.
For SGD, this is equivalent to adding the penalty (weight_decay / 2) * ||w||^2 to the loss, though the optimizer applies it directly to the gradients rather than modifying the loss.
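To see why the gradient-side decay and an explicit L2 penalty in the loss coincide for plain SGD, here is a minimal one-step sketch in plain Python (made-up values; this illustrates the math, not PyTorch's internals):

```python
lr, wd = 0.1, 0.01
w, grad = 2.0, 0.5   # current weight and its loss gradient (made-up values)

# Way 1: decay folded into the gradient, as the optimizer does internally
w1 = w - lr * (grad + wd * w)

# Way 2: differentiate an explicit penalty (wd / 2) * w**2 added to the loss;
# its gradient contribution is wd * w, giving the same update
penalty_grad = wd * w
w2 = w - lr * (grad + penalty_grad)

print(w1, w2)  # both ways produce the identical updated weight
```

Both updates subtract the same extra term lr * wd * w, which is what pulls weights toward zero each step.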
Examples
SGD optimizer with learning rate 0.1 and L2 regularization strength 0.01.
PyTorch
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=0.01)
Adam optimizer with a small weight decay to reduce overfitting.
PyTorch
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
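Note that Adam's weight_decay couples the penalty to the adaptive gradient update. If you want decoupled weight decay (often preferred with Adam), PyTorch also provides torch.optim.AdamW, sketched here with illustrative hyperparameter values:

```python
# AdamW applies weight decay directly to the weights, decoupled from
# the adaptive gradient statistics (values here are illustrative)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
```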
Sample Model
This code trains a simple linear model to fit y = 2x + 1 using the SGD optimizer with weight decay to keep the weights small. It prints the loss each epoch and the final weights.
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim

# Simple linear model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)
    def forward(self, x):
        return self.linear(x)

# Create model
model = SimpleModel()

# Dummy data: y = 2x + 1
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y_train = torch.tensor([[3.0], [5.0], [7.0], [9.0]])

# Loss function
criterion = nn.MSELoss()

# Optimizer with weight decay (L2 regularization)
optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=0.01)

# Training loop
for epoch in range(10):
    optimizer.zero_grad()
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Print final weights
for name, param in model.named_parameters():
    print(f"{name}: {param.data.flatten().tolist()}")
Important Notes
Weight decay is applied automatically by the optimizer during the update step; you do not add it to the loss yourself.
For SGD, it is mathematically equivalent to manually adding an L2 penalty term to the loss.
Too large a weight_decay value can cause underfitting by shrinking the weights too aggressively.
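The underfitting risk is easy to see in a tiny plain-Python sketch: with the data gradient held at zero, the decay term alone shrinks a weight geometrically as w_t = w_0 * (1 - lr * wd)^t, and a large decay rate drives it toward zero (values below are made up for illustration):

```python
def decay_only(w0, lr, wd, steps):
    # Repeated SGD updates with the data gradient fixed at zero,
    # so only the weight-decay term acts on the weight
    w = w0
    for _ in range(steps):
        w -= lr * (wd * w)
    return w

moderate = decay_only(1.0, lr=0.1, wd=0.01, steps=100)    # gentle shrinkage
aggressive = decay_only(1.0, lr=0.1, wd=1.0, steps=100)   # weight collapses
print(moderate, aggressive)
```

With wd=0.01 the weight barely moves after 100 steps, while wd=1.0 crushes it to nearly zero, which is why overly large decay values underfit.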
Summary
Weight decay adds a penalty to large weights to reduce overfitting.
In PyTorch, set weight_decay in the optimizer to enable L2 regularization.
It helps models generalize better by keeping weights small and simple.