
Feature map visualization in PyTorch

Introduction

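Feature map visualization lets you look inside a convolutional neural network. Each feature map records how strongly one filter responds at each position of the input, so plotting the maps shows which regions and patterns a layer reacts to.

Common reasons to visualize feature maps: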

To understand how a convolutional neural network processes an image.
To check if the model is focusing on the right parts of the input.
To debug or improve model design by seeing intermediate outputs.
To explain model decisions visually to others.
To learn how different layers transform the input data.
Syntax
PyTorch
feature_maps = model.layer(input_tensor)  # 'layer' is a placeholder for a conv layer, e.g. model.conv1
# feature_maps shape: (batch_size, channels, height, width)

# To visualize, detach feature_maps, move them to the CPU, convert to NumPy, and plot each channel as an image

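Feature maps are the outputs of convolutional layers.

For 2D layers such as nn.Conv2d, they have 4 dimensions: batch size, channels, height, and width. The number of channels equals the layer's out_channels, so there is one feature map per filter.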
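The height and width of an nn.Conv2d feature map follow from the input size, kernel size k, padding p, and stride s (with dilation 1): H_out = floor((H_in + 2p - k) / s) + 1, and the same for the width. A minimal sketch that checks this against real layers; the helper name conv2d_out_size is hypothetical, chosen for illustration.

```python
import torch
import torch.nn as nn

def conv2d_out_size(size: int, kernel: int, padding: int = 0, stride: int = 1) -> int:
    """Output height or width of nn.Conv2d with dilation=1 (hypothetical helper)."""
    return (size + 2 * padding - kernel) // stride + 1

x = torch.randn(1, 1, 28, 28)  # (batch, channels, height, width)

# padding=1 with a 3x3 kernel keeps the 28x28 spatial size
same = nn.Conv2d(1, 4, kernel_size=3, padding=1)
# no padding and stride 2: the spatial size shrinks
strided = nn.Conv2d(1, 8, kernel_size=3, stride=2)

print(same(x).shape)     # torch.Size([1, 4, 28, 28])
print(strided(x).shape)  # torch.Size([1, 8, 13, 13])
print(conv2d_out_size(28, 3, padding=1), conv2d_out_size(28, 3, stride=2))  # 28 13
```

The channel dimension (4 and 8 here) comes straight from out_channels; only height and width depend on the formula.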

Examples
This gets the feature maps from the first convolutional layer.
PyTorch
feature_maps = model.conv1(input_image)
print(feature_maps.shape)
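Calling model.conv1 directly only works because conv1 is the first step of the forward pass. For a layer deeper in the network, a common alternative is a forward hook: PyTorch calls it with the layer's output during the normal forward pass. A minimal sketch, assuming a hypothetical two-layer model (TwoLayerCNN) and a dictionary named activations for storing the captured maps:

```python
import torch
import torch.nn as nn

class TwoLayerCNN(nn.Module):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3, padding=1)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        return torch.relu(self.conv2(x))

activations = {}  # layer name -> captured feature maps

def save_output(name):
    # Forward hooks receive (module, inputs, output); returning None keeps the output unchanged
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook

model = TwoLayerCNN()
handle = model.conv2.register_forward_hook(save_output("conv2"))

with torch.no_grad():
    model(torch.randn(1, 1, 28, 28))

handle.remove()  # stop capturing once the maps are stored
# The hook sees conv2's raw output, before the ReLU applied in forward()
print(activations["conv2"].shape)  # torch.Size([1, 8, 28, 28])
```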
This plots each channel of the feature map as a grayscale image.
PyTorch
import matplotlib.pyplot as plt

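num_channels = feature_maps.shape[1]
for i in range(num_channels):
    plt.subplot(1, num_channels, i + 1)  # one row, one panel per channel
    # detach() drops the autograd graph; cpu() moves GPU tensors to host memory
    plt.imshow(feature_maps[0, i].detach().cpu().numpy(), cmap='gray')
    plt.axis('off')
plt.show()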
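A single row becomes unreadable once a layer has dozens of channels. A minimal sketch that arranges the maps in a grid instead; the function name plot_feature_maps and its ncols parameter are hypothetical. It returns the figure, so you can call plt.show() to display it or fig.savefig(...) to write it to a file.

```python
import math

import matplotlib.pyplot as plt
import torch

def plot_feature_maps(feature_maps: torch.Tensor, ncols: int = 8):
    """Plot every channel of the first image in an (N, C, H, W) batch on a grid."""
    maps = feature_maps[0].detach().cpu().numpy()  # (C, H, W)
    n_channels = maps.shape[0]
    ncols = min(ncols, n_channels)
    nrows = math.ceil(n_channels / ncols)

    # squeeze=False keeps axs 2-D even when nrows or ncols is 1
    fig, axs = plt.subplots(nrows, ncols, figsize=(1.5 * ncols, 1.5 * nrows), squeeze=False)
    for i, ax in enumerate(axs.flat):
        ax.axis("off")  # also hides the unused panels in the last row
        if i < n_channels:
            # imshow scales each map to its own min/max: compare patterns, not absolute values
            ax.imshow(maps[i], cmap="gray")
            ax.set_title(f"Channel {i + 1}", fontsize=8)
    fig.tight_layout()
    return fig

fig = plot_feature_maps(torch.randn(1, 10, 28, 28))
print(len(fig.axes))  # 16: a 2x8 grid with the last 6 panels left empty
```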
Sample Model

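This code creates a simple CNN with one convolutional layer (1 input channel, 4 output channels). It passes a random tensor shaped like a 28x28 grayscale image through the model, gets the feature maps, prints their shape, and shows each channel as a small image. Because the model is untrained and the input is random noise, the maps will look like noise as well; use a trained model and a real image to see meaningful patterns.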

PyTorch
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Simple CNN model with a single convolutional layer
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # 1 input channel (grayscale) -> 4 feature maps; padding=1 keeps the 28x28 size
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        return x

# Create model and dummy input
model = SimpleCNN()
input_tensor = torch.randn(1, 1, 28, 28)  # batch=1, channel=1, 28x28 image

# Get feature maps from conv1; no_grad() skips building the autograd graph
with torch.no_grad():
    feature_maps = model(input_tensor)

# Print shape: torch.Size([1, 4, 28, 28])
print(f"Feature maps shape: {feature_maps.shape}")

# Plot feature maps, one panel per channel
fig, axs = plt.subplots(1, feature_maps.shape[1], figsize=(12, 3))
for i in range(feature_maps.shape[1]):
    axs[i].imshow(feature_maps[0, i].detach().cpu().numpy(), cmap='gray')
    axs[i].axis('off')
    axs[i].set_title(f'Channel {i+1}')
plt.show()
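The sample model has a single layer. When a model is built as nn.Sequential, one way to collect the feature maps after every layer is to run the input through its layers one at a time. A minimal sketch with a hypothetical conv, ReLU, max-pool, conv stack; it prints how the shape changes at each step. This only works for models whose forward pass is a plain sequence; for models with branches or custom forward logic, forward hooks (register_forward_hook) are the usual tool.

```python
import torch
import torch.nn as nn

# Hypothetical model: conv -> ReLU -> pool -> conv
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),  # halves height and width
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
)

x = torch.randn(1, 1, 28, 28)
outputs = []  # feature maps after each layer, in order

with torch.no_grad():
    for layer in model:  # iterating nn.Sequential yields its layers in order
        x = layer(x)
        outputs.append(x)
        print(f"{type(layer).__name__:>10}: {tuple(x.shape)}")
```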
Important Notes

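Detach the feature maps from the computation graph with .detach() before converting or plotting them. Tensors that require gradients cannot be converted to NumPy, and holding references to them keeps the autograd graph in memory. Running the forward pass inside torch.no_grad() avoids building the graph at all.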
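A minimal sketch of what detaching changes: the output of a layer with trainable weights requires gradients, .numpy() refuses such tensors with a RuntimeError, and torch.no_grad() skips recording the graph in the first place.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)
x = torch.randn(1, 1, 28, 28)

out = conv(x)
print(out.requires_grad)  # True: autograd recorded this operation

try:
    out.numpy()  # fails because the tensor requires grad
except RuntimeError as err:
    print("RuntimeError:", err)

maps = out.detach().numpy()  # works: detach() returns a tensor outside the graph

with torch.no_grad():  # no graph is recorded at all
    out_ng = conv(x)
print(out_ng.requires_grad)  # False
```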

If your tensors are on the GPU, call .cpu() to copy them to host memory before converting them to NumPy or plotting them; NumPy cannot read GPU memory.
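A minimal device-agnostic sketch, assuming the common pattern of selecting CUDA when it is available and falling back to the CPU otherwise. Because .cpu() returns the tensor itself when it is already on the CPU, the same conversion line works on either device.

```python
import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

conv = nn.Conv2d(1, 4, kernel_size=3, padding=1).to(device)
x = torch.randn(1, 1, 28, 28, device=device)  # input must be on the same device as the model

with torch.no_grad():
    maps = conv(x)

# .cpu() copies GPU tensors to host memory and is a no-op for CPU tensors
maps_np = maps.cpu().numpy()
print(maps.device, maps_np.shape)
```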

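Feature maps show where each filter's pattern appears in the input. In trained networks, early layers typically respond to simple patterns such as edges, while deeper layers respond to textures and more complex shapes.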
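To see what "detecting an edge" looks like in a feature map, you can set a filter's weights by hand. A minimal sketch, assuming a synthetic 8x8 image that is dark on the left and bright on the right, and a Sobel-style vertical-edge kernel (hand-chosen weights, not learned ones). The feature map is zero over flat regions and nonzero only where the 3x3 window straddles the boundary.

```python
import torch
import torch.nn as nn

# Synthetic image: columns 0-3 are 0 (dark), columns 4-7 are 1 (bright)
img = torch.zeros(1, 1, 8, 8)
img[..., 4:] = 1.0

conv = nn.Conv2d(1, 1, kernel_size=3, bias=False)  # no padding: 8x8 -> 6x6
with torch.no_grad():
    # Sobel-style kernel that responds to left-to-right increases in brightness
    conv.weight.copy_(torch.tensor([[[[-1.0, 0.0, 1.0],
                                      [-2.0, 0.0, 2.0],
                                      [-1.0, 0.0, 1.0]]]]))
    fmap = conv(img)[0, 0]

# Every row is [0., 0., 4., 4., 0., 0.]: the response peaks along the edge
print(fmap)
```

Learned filters are rarely this clean, but early-layer filters in trained CNNs often resemble edge detectors like this one.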

Summary

Feature maps are the outputs of convolutional layers and show where each learned filter responds in the input.

Visualizing them helps you understand and debug CNNs.

Plot each channel as an image to see which regions of the input activate each filter.