Experiment - Batch normalization (nn.BatchNorm1d)
Problem: You are training a neural network on the MNIST dataset to classify handwritten digits. The current model uses simple linear layers with ReLU activations but no batch normalization.
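A minimal sketch of the described baseline, assuming PyTorch and hypothetical layer widths (784 → 256 → 128 → 10; the source does not specify the architecture beyond "linear layers with ReLU"):

```python
import torch.nn as nn

# Baseline MLP for MNIST: linear layers + ReLU, no batch normalization.
# Layer sizes are assumptions for illustration.
baseline = nn.Sequential(
    nn.Flatten(),          # 28x28 images -> 784-dim vectors
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10),    # logits for the 10 digit classes
)
```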
Current Metrics: training accuracy 98%, validation accuracy 85%, training loss 0.05, validation loss 0.45
Issue: The model shows signs of overfitting: training accuracy (98%) is far above validation accuracy (85%), and validation loss (0.45) is roughly nine times the training loss (0.05). Because batch normalization computes statistics per mini-batch, it injects a small amount of noise during training and acts as a mild regularizer, which may help narrow this gap.
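A minimal sketch of the experiment's change, using the same assumed layer widths as above: insert nn.BatchNorm1d after each hidden linear layer, before the ReLU. (Placement before vs. after the activation is a design choice; before the nonlinearity is the common default.)

```python
import torch.nn as nn

# Same assumed architecture, with BatchNorm1d added after each hidden
# linear layer. BatchNorm1d is used (not 2d) because the inputs are
# flattened feature vectors, not spatial feature maps.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 256),
    nn.BatchNorm1d(256),   # normalize 256 features across the batch
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(),
    nn.Linear(128, 10),
)
```

Note that batch normalization behaves differently at train and test time (batch statistics vs. running averages), so call `model.train()` during training and `model.eval()` before computing the validation metrics, or the reported validation numbers will be skewed.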