How to Use BatchNorm2d in PyTorch: Syntax and Example
Use torch.nn.BatchNorm2d to normalize the outputs of 2D convolutional layers. Initialize it with the number of feature channels, then apply it to the convolution outputs during training to stabilize and speed up learning.

Syntax
The BatchNorm2d layer is initialized with the number of feature channels it will normalize. It has optional parameters to control behavior like eps (small number to avoid division by zero), momentum (for running statistics), and whether to learn affine parameters.
Typical usage:

```python
torch.nn.BatchNorm2d(num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)
```

Parameters:
- num_features: number of channels from the convolution output.
- eps: small float added to the variance for numerical stability (default 1e-5).
- momentum: momentum for the running mean/variance estimates (default 0.1).
- affine: if True, the layer learns per-channel scale and shift parameters.
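To see what these parameters create, here is a quick sketch (using the default argument values) that instantiates the layer and inspects the tensors it holds:

```python
import torch
import torch.nn as nn

# Instantiate with 6 feature channels; the other arguments shown are the defaults
bn = nn.BatchNorm2d(num_features=6, eps=1e-5, momentum=0.1,
                    affine=True, track_running_stats=True)

# affine=True creates learnable per-channel scale (weight) and shift (bias)
print(bn.weight.shape)        # torch.Size([6]), initialized to ones
print(bn.bias.shape)          # torch.Size([6]), initialized to zeros

# track_running_stats=True keeps running estimates used in eval mode
print(bn.running_mean.shape)  # torch.Size([6]), starts at zeros
print(bn.running_var.shape)   # torch.Size([6]), starts at ones
```

Every per-channel tensor has length num_features, which is why it must match the convolution's output channels.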
Example
This example shows how to create a simple convolutional layer followed by BatchNorm2d. It runs a random input tensor through both layers and prints the output shape and values.
```python
import torch
import torch.nn as nn

# Define a simple model with Conv2d and BatchNorm2d
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(num_features=6)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

# Create model and input
model = SimpleCNN()
input_tensor = torch.randn(1, 3, 4, 4)  # batch size 1, 3 channels, 4x4 image

# Forward pass
output = model(input_tensor)
print("Output shape:", output.shape)
print("Output tensor:", output)
```
Output
Output shape: torch.Size([1, 6, 4, 4])
Output tensor: tensor([[[[ 0.1234, -0.5678, ...], ... ]]], grad_fn=<NativeBatchNormBackward0>)
Common Pitfalls
- num_features must match the convolution's output channels; a mismatch raises a runtime error.
- Running inference without calling model.eval() gives wrong normalization: the layer keeps using per-batch statistics instead of its accumulated running statistics.
- Applying batch normalization before the convolution instead of after it is incorrect.
- For very small batch sizes, batch normalization may be unstable.
```python
import torch
import torch.nn as nn

# Wrong: num_features does not match the conv output channels
try:
    wrong_bn = nn.BatchNorm2d(num_features=3)  # conv outputs 6 channels
    x = torch.randn(1, 6, 4, 4)
    wrong_bn(x)
except Exception as e:
    print("Error:", e)

# Right: num_features matches the conv output channels
correct_bn = nn.BatchNorm2d(num_features=6)
x = torch.randn(1, 6, 4, 4)
output = correct_bn(x)
print("BatchNorm output shape:", output.shape)
```
Output
Error: running_mean should contain 6 elements not 3
BatchNorm output shape: torch.Size([1, 6, 4, 4])
Quick Reference
BatchNorm2d Cheat Sheet:
- num_features: must equal the conv output channels.
- Use model.train() for training mode and model.eval() for evaluation mode.
- BatchNorm normalizes each channel independently over the batch and spatial dimensions.
- Helps speed up training and improve stability.
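The "normalizes each channel independently over batch and spatial dimensions" point can be verified by hand. This sketch (with affine=False to disable the learnable scale and shift) reproduces the layer's training-mode output by computing per-channel statistics over dimensions 0, 2, and 3; note that BatchNorm2d normalizes with the biased variance:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 4, 4)
bn = nn.BatchNorm2d(3, affine=False)  # no scale/shift, for a clean comparison
bn.train()
out = bn(x)

# Manual equivalent: per-channel mean/var over batch and spatial dims (0, 2, 3)
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)  # biased variance
manual = (x - mean) / torch.sqrt(var + bn.eps)

print(torch.allclose(out, manual, atol=1e-6))  # True
```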
Key Takeaways
Initialize BatchNorm2d with the number of output channels from your convolution layer.
Always apply BatchNorm2d after convolution and before activation functions.
Switch your model to eval mode with model.eval() to use running statistics during inference.
BatchNorm2d normalizes each channel independently across batch and spatial dimensions.
Incorrect num_features or mode settings cause runtime errors or poor model performance.
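Putting the takeaways together, a conventional block orders the layers as convolution, then batch norm, then activation, and switches to eval mode for inference. A minimal sketch with nn.Sequential:

```python
import torch
import torch.nn as nn

# Conventional ordering: convolution -> batch norm -> activation
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),  # num_features matches the conv's out_channels
    nn.ReLU(),
)

block.eval()  # inference: use running statistics, not batch statistics
with torch.no_grad():
    y = block(torch.randn(1, 3, 8, 8))
print(y.shape)  # torch.Size([1, 16, 8, 8])
```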