PyTorch · ~20 mins

Feature map visualization in PyTorch - ML Experiment: Train & Evaluate

Experiment - Feature map visualization
Problem: You have trained a convolutional neural network (CNN) on image data. You want to understand what the network learns by visualizing the feature maps (the outputs of its convolutional layers) for a sample input image.
Current Metrics: The model reaches 85% accuracy on training data and 80% on validation data. No feature maps have been visualized yet.
Issue: Without visualizing feature maps, it is hard to interpret what the CNN focuses on in the images, which limits your ability to understand and debug the model.
Your Task
Visualize the feature maps of the first convolutional layer for a given input image to understand what features the CNN extracts.
Use PyTorch framework.
Do not modify the trained model weights.
Visualize only the first convolutional layer's output feature maps.
Use matplotlib for plotting.
Solution
PyTorch
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

# Define a simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Instantiate the model (in practice, load your trained weights with
# load_state_dict; a randomly initialized model is used here for demonstration)
model = SimpleCNN()
model.eval()

# Prepare a sample image from CIFAR10 validation set
transform = transforms.Compose([
    transforms.ToTensor(),
])

val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=True)

# Get one sample image and label
sample_img, sample_label = next(iter(val_loader))

# Hook to capture feature maps
feature_maps = []

def hook_fn(module, input, output):
    feature_maps.append(output.detach())

# Register hook on first conv layer
hook = model.conv1.register_forward_hook(hook_fn)

# Forward pass (no gradients needed for visualization)
with torch.no_grad():
    _ = model(sample_img)

# Remove hook
hook.remove()

# feature_maps[0] shape: [1, 6, 28, 28] (batch, channels, height, width)
fm = feature_maps[0][0]  # remove batch dimension

# Plot feature maps
num_maps = fm.shape[0]
cols = 3
rows = (num_maps + cols - 1) // cols
plt.figure(figsize=(cols * 3, rows * 3))
for i in range(num_maps):
    plt.subplot(rows, cols, i + 1)
    # Normalize to 0-1 for visualization
    fmap = fm[i]
    fmap = (fmap - fmap.min()) / (fmap.max() - fmap.min() + 1e-5)
    plt.imshow(fmap.cpu(), cmap='gray')
    plt.axis('off')
    plt.title(f'Feature map {i+1}')
plt.tight_layout()
plt.show()
Added a forward hook on the first convolutional layer to capture its output feature maps.
Extracted feature maps for a sample input image from the validation set.
Normalized each feature map to the 0-1 range for clear visualization.
Plotted all feature maps in a grid using matplotlib.
Added a small epsilon (1e-5) to denominator in normalization to avoid division by zero.
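For a single early layer there is also a simpler alternative to hooks: call the layer directly on the input. A minimal sketch, assuming the same SimpleCNN architecture and a CIFAR-10-sized input (a random tensor stands in for the sample image here):

```python
import torch
import torch.nn as nn

# Stand-in for model.conv1 from the solution (same shape: 3 -> 6, 5x5 kernel)
conv1 = nn.Conv2d(3, 6, kernel_size=5)

# Stand-in for a CIFAR-10 image batch: [batch, channels, height, width]
sample_img = torch.randn(1, 3, 32, 32)

# Calling the layer directly gives the same tensor a forward hook would capture
with torch.no_grad():
    fm = conv1(sample_img)[0]  # drop the batch dimension -> [6, 28, 28]

print(fm.shape)
```

Hooks become the better tool once you want outputs of layers buried deeper in the network, or several layers at once, without re-running parts of the forward pass by hand.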
Results Interpretation

Before: No insight into what the CNN learns internally.

After: Visualized 6 feature maps from the first convolutional layer, each highlighting different edges and textures in the input image.

Visualizing feature maps helps you understand what patterns the CNN detects early on, which aids in interpreting and debugging the model.
Bonus Experiment
Visualize feature maps from the second convolutional layer and compare them with the first layer's maps.
💡 Hint
Register a hook on the second conv layer and plot its feature maps similarly. Notice how deeper layers capture more complex features.
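A minimal sketch of the bonus setup, using standalone copies of the solution's layers (a random tensor stands in for the CIFAR-10 sample image):

```python
import torch
import torch.nn as nn

# Standalone copies of the solution's first two stages
conv1 = nn.Conv2d(3, 6, kernel_size=5)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(6, 16, kernel_size=5)

# Hook on the SECOND conv layer this time
feature_maps = []
hook = conv2.register_forward_hook(lambda mod, inp, out: feature_maps.append(out.detach()))

sample_img = torch.randn(1, 3, 32, 32)  # stand-in for a CIFAR-10 image
with torch.no_grad():
    x = pool(torch.relu(conv1(sample_img)))  # [1, 6, 14, 14]
    _ = conv2(x)                             # triggers the hook

hook.remove()

# Deeper layer: more channels (16 vs 6), smaller spatial size (10x10 vs 28x28)
fm2 = feature_maps[0][0]
print(fm2.shape)
```

The plotting loop from the solution works unchanged on `fm2`; with 16 maps you may want a wider grid (for example, `cols = 4`).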