How to Use nn.BatchNorm2d in PyTorch: Syntax and Example
Use nn.BatchNorm2d in PyTorch to normalize 2D feature maps by specifying the number of channels. It helps stabilize and speed up training of convolutional neural networks by normalizing each channel's activations across the batch and spatial dimensions.

Syntax

The nn.BatchNorm2d layer is initialized with the number of feature channels it will normalize. Optional parameters control its behavior during training and inference.
- num_features: Number of channels in the input tensor.
- eps: Small value to avoid division by zero (default 1e-5).
- momentum: Value for running mean/variance update (default 0.1).
- affine: If True, the layer has learnable scale and shift parameters.
- track_running_stats: If True, tracks running mean and variance for use in evaluation mode.
```python
nn.BatchNorm2d(num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)
```
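As a quick check of these parameters (a minimal sketch using the standard PyTorch API): with the defaults, the layer creates per-channel learnable parameters and per-channel running-statistic buffers, all of shape (num_features,).

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(num_features=3)

# affine=True creates learnable per-channel scale (weight, init 1) and shift (bias, init 0)
print(bn.weight.shape, bn.bias.shape)  # torch.Size([3]) torch.Size([3])

# track_running_stats=True creates running_mean (init 0) and running_var (init 1) buffers
print(bn.running_mean, bn.running_var)
```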
Example
This example shows how to create a BatchNorm2d layer and apply it to a random 4D tensor representing a batch of images with 3 channels. It prints the input and output shapes and verifies the normalization effect.
```python
import torch
import torch.nn as nn

# Create BatchNorm2d for 3 channels
batch_norm = nn.BatchNorm2d(num_features=3)

# Random input: batch size 2, 3 channels, 4x4 image
input_tensor = torch.randn(2, 3, 4, 4)

# Apply batch normalization
output_tensor = batch_norm(input_tensor)

print('Input shape:', input_tensor.shape)
print('Output shape:', output_tensor.shape)

# Check mean and std per channel in the output (approximately 0 and 1 during training)
mean = output_tensor.mean(dim=[0, 2, 3])
std = output_tensor.std(dim=[0, 2, 3])
print('Output mean per channel:', mean)
print('Output std per channel:', std)
```
Output
Input shape: torch.Size([2, 3, 4, 4])
Output shape: torch.Size([2, 3, 4, 4])
Output mean per channel: tensor([-0.0006, 0.0002, 0.0003], grad_fn=<MeanBackward0>)
Output std per channel: tensor([1.0001, 1.0000, 1.0000], grad_fn=<StdBackward0>)
Common Pitfalls
- Using BatchNorm2d in evaluation mode without switching: Remember to call model.eval() to use running statistics instead of batch statistics during inference.
- Incorrect input shape: The input must be a 4D tensor with shape (batch_size, channels, height, width). Passing other shapes raises an error.
- Not using affine=True when needed: If you want the layer to learn scale and shift, keep affine=True. Otherwise, the output is fixed to zero mean and unit variance with no learnable adjustment.
```python
import torch
import torch.nn as nn

batch_norm = nn.BatchNorm2d(3)
input_tensor = torch.randn(2, 3, 4, 4)

# Correct: switch to eval mode during inference
batch_norm.eval()
output_eval = batch_norm(input_tensor)  # Uses running stats

# Wrong input shape example
try:
    wrong_input = torch.randn(2, 4, 4)  # Missing channel dimension
    batch_norm(wrong_input)
except Exception as e:
    print('Error:', e)
Output
Error: expected 4D input (got 3D input)
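To illustrate the third pitfall (a small sketch using the standard PyTorch API): with affine=False the layer registers no learnable parameters at all, so nothing about the normalization can be adjusted during training.

```python
import torch.nn as nn

# With affine=False, BatchNorm2d has no learnable scale/shift
bn_fixed = nn.BatchNorm2d(3, affine=False)
print(bn_fixed.weight, bn_fixed.bias)       # None None
print(len(list(bn_fixed.parameters())))     # 0
```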
Quick Reference
Tips for using nn.BatchNorm2d:
- Set num_features equal to the number of channels in your input.
- Call model.train() during training and model.eval() during evaluation to switch batch norm behavior.
- BatchNorm2d normalizes each channel independently over the batch and spatial dimensions.
- It helps reduce internal covariate shift and speeds up training convergence.
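The per-channel normalization described above can be reproduced by hand. This sketch compares a freshly initialized BatchNorm2d (scale 1, shift 0) in training mode against a manual computation over the batch and spatial dimensions; note that BatchNorm uses the biased variance for normalization.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
x = torch.randn(2, 3, 4, 4)
y = bn(x)  # training mode: normalizes with batch statistics

# Manual normalization over batch and spatial dims (0, 2, 3)
mean = x.mean(dim=[0, 2, 3], keepdim=True)
var = x.var(dim=[0, 2, 3], unbiased=False, keepdim=True)  # biased variance
manual = (x - mean) / torch.sqrt(var + bn.eps)

print(torch.allclose(y, manual, atol=1e-6))  # True
```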
Key Takeaways
nn.BatchNorm2d normalizes 2D feature maps channel-wise to stabilize training.
Always use 4D input tensors with shape (batch, channels, height, width).
Switch between training and evaluation modes to use batch or running statistics.
Set affine=True to allow learnable scaling and shifting after normalization.
BatchNorm2d speeds up training and improves model performance in CNNs.
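To make the mode-switching takeaway concrete, here is a minimal sketch of how the running statistics are updated in train mode and then used in eval mode. The shifted input makes the update visible.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)  # running_mean starts at 0, running_var at 1
x = torch.randn(16, 3, 4, 4) + 5.0  # shifted data so the update is visible

bn.train()
_ = bn(x)  # updates running stats: new = (1 - momentum) * old + momentum * batch
print(bn.running_mean)  # moved toward the batch mean (~5), scaled by momentum 0.1

bn.eval()
y = bn(x)  # normalizes with running stats, so the output mean is no longer ~0
print(y.mean().item())
```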