import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
# Define a simple CNN model
class SimpleCNN(nn.Module):
    """LeNet-style CNN for 3x32x32 images (e.g. CIFAR-10) producing 10 logits.

    Layers: conv(3->6, 5x5) -> ReLU -> maxpool(2) -> conv(6->16, 5x5) -> ReLU
    -> maxpool(2) -> fc(400->120) -> ReLU -> fc(120->84) -> ReLU -> fc(84->10).
    """

    # Per-sample feature shape after both conv+pool stages for a 32x32 input:
    # 32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5, with 16 channels.
    _CONV_OUT_SHAPE = (16, 5, 5)

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return class logits for ``x``.

        Args:
            x: images of shape (N, 3, 32, 32). An unbatched (3, 32, 32) image
                is also accepted and yields logits of shape (1, 10).

        Returns:
            Tensor of shape (N, 10) (unnormalized logits from ``fc3``).

        Raises:
            RuntimeError: if the convolutional features do not have the
                per-sample shape (16, 5, 5), i.e. the input is not 32x32.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Validate before flattening: a bare view(-1, 400) on a wrong-sized
        # input whose element count happens to divide by 400 would silently
        # mix features across samples into a bogus batch instead of failing.
        if tuple(x.shape[-3:]) != self._CONV_OUT_SHAPE:
            raise RuntimeError(
                "SimpleCNN expects 32x32 inputs (conv features of shape "
                f"{self._CONV_OUT_SHAPE}), got conv features of shape "
                f"{tuple(x.shape[-3:])}"
            )
        # reshape (unlike view) also works on non-contiguous tensors.
        x = x.reshape(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# Build the model. No checkpoint is loaded: the weights are PyTorch's default
# random initialization, so the visualized feature maps come from an
# untrained network. eval() returns the module itself, so it can be chained.
model = SimpleCNN().eval()

# Data: the CIFAR-10 test split (train=False), downloaded into ./data if
# missing, with images converted to tensors by ToTensor.
transform = transforms.Compose([transforms.ToTensor()])
val_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=True,
)

# With batch_size=1 and shuffle=True, the first batch is a single randomly
# chosen image (with a leading batch dimension) and its label.
loader_iter = iter(val_loader)
sample_img, sample_label = next(loader_iter)
# Outputs recorded by hook_fn: one tensor per call of whichever module the
# hook is registered on.
feature_maps = []


def hook_fn(module, input, output):
    """Forward hook that records ``output`` (detached from autograd).

    ``module`` and ``input`` are required by PyTorch's forward-hook calling
    convention but are unused here. Returning None leaves the module's
    output unchanged.
    """
    captured = output.detach()
    feature_maps.extend([captured])
# Register a forward hook on the first conv layer so its output is appended
# to `feature_maps` during the forward pass.
hook = model.conv1.register_forward_hook(hook_fn)
try:
    # Inference only: no_grad avoids building an autograd graph that is
    # never used for a backward pass.
    with torch.no_grad():
        _ = model(sample_img)
finally:
    # Remove the hook even if the forward pass raises, so later forward
    # calls do not keep appending to `feature_maps`.
    hook.remove()
# feature_maps[0] shape: [1, 6, 28, 28] (batch, channels, height, width);
# a 5x5 conv on a 32x32 CIFAR-10 image gives 32 - 5 + 1 = 28.
fm = feature_maps[0][0]  # remove batch dimension -> [6, 28, 28]
# Draw every channel of `fm` as a grayscale image in a grid of 3 columns.
num_maps = fm.shape[0]
cols = 3
rows = -(-num_maps // cols)  # ceiling division: enough rows for all maps
plt.figure(figsize=(cols * 3, rows * 3))
for idx, channel in enumerate(fm, start=1):
    plt.subplot(rows, cols, idx)
    # Min-max scale into [0, 1]; the small epsilon keeps a constant channel
    # from dividing by zero.
    lo, hi = channel.min(), channel.max()
    scaled = (channel - lo) / (hi - lo + 1e-5)
    plt.imshow(scaled.cpu(), cmap='gray')
    plt.axis('off')
    plt.title(f'Feature map {idx}')
plt.tight_layout()
plt.show()