
Batch normalization (nn.BatchNorm1d/2d/3d) in PyTorch

Introduction

Batch normalization helps a neural network learn faster and better by keeping the data balanced inside the network. Use it:

When training deep neural networks and you want to speed up learning.
When you want to reduce the chance of the model getting stuck during training.
When you want the model to be less sensitive to its starting weight values.
When you want to improve the model's accuracy on new data.
When you want to stabilize the training process.
Syntax
PyTorch
torch.nn.BatchNorm1d(num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)

# For 2D data (images), use BatchNorm2d
# For 3D data (videos), use BatchNorm3d

num_features is the number of features or channels in your data.

eps is a small value added to the variance to avoid division by zero.

momentum sets how quickly the running mean and variance update during training.

affine, when True, adds learnable scale and shift parameters.

track_running_stats, when True, keeps running estimates of the mean and variance for use during evaluation.

Examples
Creates batch normalization for 10 features in a 1D input like a vector.
PyTorch
bn = torch.nn.BatchNorm1d(10)
# For 10 features in a 1D input
Creates batch normalization for 3 channels in image data (like RGB images).
PyTorch
bn2d = torch.nn.BatchNorm2d(3)
# For 3 channels in image data
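BatchNorm2d normalizes each channel independently and expects 4D input shaped (batch, channels, height, width). A quick sketch of applying it to a batch of random image-like tensors:

```python
import torch
import torch.nn as nn

bn2d = nn.BatchNorm2d(3)

# BatchNorm2d expects input shaped (batch, channels, height, width)
images = torch.randn(8, 3, 32, 32)
out = bn2d(images)

print(out.shape)  # torch.Size([8, 3, 32, 32])
```

The output keeps the input's shape; only the values per channel are normalized.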
Adjusts how fast the running mean and variance update during training.
PyTorch
bn = torch.nn.BatchNorm1d(5, momentum=0.05)
# Using a smaller momentum for running stats
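The running statistics follow an exponential moving average: new = (1 - momentum) * old + momentum * batch_stat. A small check of that update rule with momentum=0.5 (the input values here are just illustrative):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(2, momentum=0.5)
x = torch.tensor([[0.0, 4.0],
                  [2.0, 8.0]])

bn(x)  # one forward pass in training mode updates the running stats

# Update rule: new = (1 - momentum) * old + momentum * batch_mean
# Old running mean is [0., 0.], this batch's mean is [1., 6.]
print(bn.running_mean)  # tensor([0.5000, 3.0000])
```

A smaller momentum makes the running statistics change more slowly, which smooths out noisy batches.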
Sample Model

This example shows how batch normalization adjusts the input data to have a mean close to 0 and variance close to 1 for each feature across the batch.

PyTorch
import torch
import torch.nn as nn

# Create batch norm for 4 features
batch_norm = nn.BatchNorm1d(4)

# Sample input: batch of 3 samples, each with 4 features
input_data = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                           [2.0, 3.0, 4.0, 5.0],
                           [3.0, 4.0, 5.0, 6.0]])

# Apply batch normalization
output = batch_norm(input_data)

print("Input:")
print(input_data)
print("\nOutput after BatchNorm:")
print(output)

# Check running mean and variance
print("\nRunning mean:", batch_norm.running_mean)
print("Running var:", batch_norm.running_var)
OutputSuccess
Important Notes

BatchNorm uses the batch's mean and variance during training, but uses running averages during evaluation.

Remember to switch your model to evaluation mode with model.eval() when testing.
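A quick sketch of that train/eval difference, using a fresh layer and the same input in both modes:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(2)
x = torch.tensor([[1.0, 10.0],
                  [3.0, 20.0]])

bn.train()   # training mode: normalizes with this batch's mean and variance
out_train = bn(x)

bn.eval()    # evaluation mode: normalizes with the stored running averages
out_eval = bn(x)

print(out_train)
print(out_eval)  # different values for the same input
```

Forgetting model.eval() at test time means single samples or small batches get normalized by their own (noisy) statistics instead of the stable running averages.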

BatchNorm layers have learnable parameters to scale and shift the normalized data.
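Those parameters are exposed as weight (the scale, often called gamma) and bias (the shift, beta), and they are updated by the optimizer like any other parameters:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)

print(bn.weight)  # scale (gamma), initialized to ones
print(bn.bias)    # shift (beta), initialized to zeros
```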

Summary

Batch normalization keeps data balanced inside the network to help learning.

It normalizes each feature using batch statistics during training.

It improves speed, stability, and accuracy of neural networks.