Feature map visualization shows what a convolutional neural network computes internally. Each feature map is a layer's response to the input, so plotting it reveals which parts of the input the model responds to.
Feature map visualization in PyTorch
Introduction
Use feature map visualization:
To understand how a convolutional neural network processes an image.
To check whether the model is focusing on the right parts of the input.
To debug or improve the model design by inspecting intermediate outputs.
To explain model decisions visually to others.
To learn how different layers transform the input data.
Syntax
PyTorch
feature_maps = model.layer(input_tensor)
# feature_maps shape: (batch_size, channels, height, width)
# To visualize, convert feature_maps to numpy and plot each channel as an image
Feature maps are the outputs of convolutional layers.
They have 4 dimensions: batch size, channels, height, and width.
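Calling `model.layer(input_tensor)` directly only works when that layer receives the raw input. For a deeper layer, one common approach is a forward hook, which records the layer's output during a normal forward pass. The sketch below assumes a hypothetical two-layer model; `TinyNet`, `conv2`, `captured`, and `save_output` are illustrative names, not part of the original example, while `register_forward_hook` is a standard `nn.Module` method.

```python
import torch
import torch.nn as nn

# Hypothetical two-layer model, used only for illustration
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3, padding=1)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        return self.conv2(x)

captured = {}  # layer name -> feature maps from the last forward pass

def save_output(name):
    # Build a hook that stores the layer's output under the given name
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook

model = TinyNet()
model.eval()
handle = model.conv2.register_forward_hook(save_output("conv2"))

with torch.no_grad():
    model(torch.randn(1, 1, 28, 28))

handle.remove()  # remove the hook once the feature maps are captured
print(captured["conv2"].shape)  # torch.Size([1, 8, 28, 28])
```

The captured tensor has the same four dimensions described above, so it can be plotted channel by channel in the same way as the output of the first layer.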
Examples
This gets the feature maps from the first convolutional layer.
PyTorch
feature_maps = model.conv1(input_image)
print(feature_maps.shape)
This plots each channel of the feature map as a grayscale image.
PyTorch
import matplotlib.pyplot as plt

for i in range(feature_maps.shape[1]):
    plt.subplot(1, feature_maps.shape[1], i+1)
    plt.imshow(feature_maps[0, i].detach().cpu(), cmap='gray')
    plt.axis('off')
plt.show()
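A single row of subplots gets crowded once a layer has more than a handful of channels. As a rough sketch, the channels can be laid out on a grid instead. The helper name `plot_feature_grid` and its `sample` and `cols` parameters are assumptions made for this example, and the random tensor stands in for a real layer output.

```python
import math

import torch
import matplotlib.pyplot as plt

def plot_feature_grid(feature_maps, sample=0, cols=8):
    """Plot every channel of one sample on a grid and return the figure."""
    maps = feature_maps[sample].detach().cpu().numpy()
    n_channels = maps.shape[0]
    rows = math.ceil(n_channels / cols)
    fig, axs = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5),
                            squeeze=False)
    for i, ax in enumerate(axs.flat):
        ax.axis("off")  # hide ticks, including on unused grid cells
        if i < n_channels:
            ax.imshow(maps[i], cmap="gray")
            ax.set_title(str(i + 1), fontsize=8)
    return fig

# Stand-in for the output of a layer with 20 channels
fig = plot_feature_grid(torch.randn(1, 20, 14, 14), cols=8)
```

Call `plt.show()` to display the figure, or `fig.savefig(...)` to write it to a file.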
Sample Model
This code creates a simple CNN with one convolutional layer. It passes a random image through the model, gets the feature maps, prints their shape, and shows each channel as a small image.
PyTorch
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        return x

# Create model and dummy input
model = SimpleCNN()
input_tensor = torch.randn(1, 1, 28, 28)  # batch=1, channel=1, 28x28 image

# Get feature maps from conv1
feature_maps = model(input_tensor)

# Print shape
print(f"Feature maps shape: {feature_maps.shape}")

# Plot feature maps
fig, axs = plt.subplots(1, feature_maps.shape[1], figsize=(12, 3))
for i in range(feature_maps.shape[1]):
    axs[i].imshow(feature_maps[0, i].detach().cpu(), cmap='gray')
    axs[i].axis('off')
    axs[i].set_title(f'Channel {i+1}')
plt.show()
Important Notes
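Detach the feature maps from the computation graph before plotting. A tensor that requires gradients cannot be converted to NumPy directly, and keeping the graph alive wastes memory.
Call cpu() on tensors that live on the GPU before converting them to NumPy or plotting.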
Feature maps show patterns detected by filters, like edges or textures.
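The first two notes can be combined into one small conversion helper. This is a minimal sketch: `to_image` is a hypothetical name, and the min-max scaling to [0, 1] is an optional extra that the notes above do not require. The random tensor stands in for a real feature map.

```python
import numpy as np
import torch

def to_image(channel):
    """Convert one 2-D feature map channel to a NumPy array scaled to [0, 1]."""
    # detach() drops the graph, cpu() moves GPU data to host memory
    arr = channel.detach().cpu().numpy()
    lo, hi = arr.min(), arr.max()
    if hi == lo:
        return np.zeros_like(arr)  # constant map: avoid dividing by zero
    return (arr - lo) / (hi - lo)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Stand-in for a layer output that is still attached to the graph
feature_maps = torch.randn(1, 4, 28, 28, device=device, requires_grad=True)

img = to_image(feature_maps[0, 0])
print(type(img).__name__, img.shape)  # ndarray (28, 28)
```

The returned array can be passed straight to `plt.imshow`.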
Summary
Feature maps are outputs of convolution layers showing learned patterns.
Visualizing them helps understand and debug CNNs.
Plot each channel as an image to see what the model focuses on.