How to Initialize Weights in PyTorch: Simple Guide
In PyTorch, you initialize weights by defining a function that applies initialization methods like
torch.nn.init.xavier_uniform_ or torch.nn.init.normal_ to model layers. You then apply this function to your model using model.apply() to set the weights before training.

Syntax
To initialize weights in PyTorch, define a function that takes a layer as input and applies an initialization method to its weights and biases if they exist. Then use model.apply(your_init_function) to apply it to all layers.
```python
def init_weights(m):                             # function to initialize weights
    if isinstance(m, torch.nn.Linear):           # check layer type
        torch.nn.init.xavier_uniform_(m.weight)  # initialize weights
        if m.bias is not None:
            m.bias.data.fill_(0.01)              # initialize biases

model.apply(init_weights)                        # apply to every submodule
```
Example
This example shows how to create a simple neural network and initialize its weights using Xavier uniform initialization for linear layers and zeros for biases.
```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

model = SimpleNet()
model.apply(init_weights)

# Check initialized weights and biases
print("fc1 weights:\n", model.fc1.weight)
print("fc1 bias:\n", model.fc1.bias)
print("fc2 weights:\n", model.fc2.weight)
print("fc2 bias:\n", model.fc2.bias)
```
Output
Xavier uniform samples each weight independently from a uniform distribution, so your values will differ on every run; only the zero biases are deterministic. A typical run looks like:
fc1 weights:
tensor([[ 0.2041, -0.3518,  0.1164, -0.5827,  0.4460,  0.0212, -0.1675,  0.5903, -0.2384,  0.3327],
[-0.4482,  0.1029,  0.6011, -0.0547, -0.3170,  0.2258,  0.4876, -0.5632,  0.0793, -0.1945],
[ 0.3364, -0.6120, -0.0981,  0.4507, -0.2736,  0.5241, -0.4058,  0.1482,  0.2615, -0.3890],
[-0.1206,  0.5518, -0.4723,  0.0358,  0.6189, -0.2901,  0.3742, -0.0664, -0.5376,  0.4135],
[ 0.0917, -0.2450,  0.3681, -0.6254,  0.1573,  0.4920, -0.3315,  0.2086,  0.5764, -0.0428]])
fc1 bias:
tensor([0., 0., 0., 0., 0.])
fc2 weights:
tensor([[ 0.4513, -0.7208,  0.1342,  0.8861, -0.3025],
[-0.5930,  0.2674, -0.8142,  0.0491,  0.6357]])
fc2 bias:
tensor([0., 0.])
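Since the printed weights are random, a more robust way to confirm the initialization worked is to check the property Xavier uniform guarantees: every weight lies in U(-a, a) with a = sqrt(6 / (fan_in + fan_out)). This is a small sketch of that check (the `within_bound` variable is illustrative, not part of the article's code):

```python
import math
import torch.nn as nn

# Xavier uniform draws from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)),
# so every weight of a freshly initialized layer must land inside that bound.
layer = nn.Linear(10, 5)
nn.init.xavier_uniform_(layer.weight)

bound = math.sqrt(6.0 / (10 + 5))  # ≈ 0.632 for a 10-in, 5-out layer
within_bound = bool(layer.weight.abs().max() <= bound)
```

A bound check like this is handy in unit tests, where asserting on exact random values is impossible.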
Common Pitfalls
Common mistakes include not checking whether a layer actually has weight or bias parameters before initializing it, which raises an AttributeError on parameter-free layers such as ReLU. Another is forgetting to call model.apply(), leaving the weights at PyTorch's defaults. A third is pairing an initialization with the wrong layer or activation (for example, Xavier assumes roughly linear or tanh activations, while Kaiming is designed for ReLU), which can slow or destabilize training.
Always check layer types and attributes before initializing.
```python
import torch.nn as nn

def wrong_init(m):
    # Fails on layers without a bias attribute (e.g. nn.ReLU),
    # and on layers where bias exists but is None
    m.bias.data.fill_(0.01)

# Correct way
def correct_init(m):
    # xavier_uniform_ requires at least a 2-D tensor, so also skip
    # 1-D weights such as those of BatchNorm layers
    if hasattr(m, 'weight') and m.weight is not None and m.weight.dim() >= 2:
        nn.init.xavier_uniform_(m.weight)
    if hasattr(m, 'bias') and m.bias is not None:
        nn.init.constant_(m.bias, 0.01)
Quick Reference
- Use model.apply(init_function) to initialize all layers.
- Check layer type with isinstance() before initializing.
- Common initializations: xavier_uniform_, kaiming_normal_, normal_, constant_.
- Always initialize biases separately if needed.
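Of the initializers listed above, only xavier_uniform_ is demonstrated in this guide; a brief sketch of kaiming_normal_, which is generally the better match for ReLU networks because it scales the variance by fan_in (the network shape here is illustrative):

```python
import torch.nn as nn

def kaiming_init(m):
    if isinstance(m, nn.Linear):
        # fan_in mode (the default) preserves activation variance
        # through layers followed by ReLU
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)

net = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
net.apply(kaiming_init)
```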
Key Takeaways
- Define a function to initialize weights and apply it to your model with model.apply().
- Always check whether layers have weights and biases before initializing to avoid errors.
- Use built-in PyTorch initializers such as xavier_uniform_ for better training performance.
- Initialize biases separately, often with zeros or small constants.
- Proper initialization helps your model converge faster and more reliably.
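One final note on style: the syntax example above writes through m.bias.data, which bypasses autograd's safety checks. The nn.init functions operate in-place and already run under no_grad internally, so preferring them (or an explicit torch.no_grad() block for raw tensor ops) is the safer idiom, sketched here on a throwaway layer:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
with torch.no_grad():
    layer.weight.fill_(0.5)  # raw in-place tensor ops need an explicit no_grad
nn.init.zeros_(layer.bias)   # nn.init functions handle no_grad internally
```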