import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Define a simple CNN
class SimpleCNN(nn.Module):
    """A small two-layer CNN for 32x32 RGB inputs (e.g. CIFAR-10).

    Layout: conv(3->6) -> relu -> pool -> conv(6->12) -> relu -> pool -> fc -> fc.
    As a side effect of ``forward``, detached copies of the activations after
    each conv+relu are stored on ``self.feature_maps1`` / ``self.feature_maps2``
    so they can be visualized after a forward pass.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, padding=1)   # 3 input channels, 6 filters
        self.conv2 = nn.Conv2d(6, 12, kernel_size=3, padding=1)  # 6 input channels, 12 filters
        self.pool = nn.MaxPool2d(2, 2)
        # 32x32 input pooled twice -> 8x8 spatial, 12 channels.
        self.fc1 = nn.Linear(12 * 8 * 8, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        """Run the network on a (batch, 3, 32, 32) tensor; return (batch, 10) logits."""
        x = F.relu(self.conv1(x))           # After conv1 + relu
        self.feature_maps1 = x.detach()     # Save feature maps after first conv
        x = self.pool(x)                    # Pooling after conv1
        x = F.relu(self.conv2(x))           # After conv2 + relu
        self.feature_maps2 = x.detach()     # Save feature maps after second conv
        x = self.pool(x)                    # Pooling after conv2
        # flatten(x, 1) keeps the batch dimension explicit; unlike view(-1, C*H*W)
        # it raises instead of silently re-batching if spatial dims are wrong.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# Load CIFAR-10 dataset (small subset for speed)
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=True, num_workers=0
)

# Initialize model, loss, optimizer
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train on a single batch only — enough to populate the saved feature maps
# for the demonstration below.
images, labels = next(iter(trainloader))
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
def _plot_feature_maps(maps, title, figsize):
    """Show every channel of `maps` (a (C, H, W) tensor) as a grayscale image in one row.

    Blocks on ``plt.show()`` until the figure window is closed.
    """
    num_maps = maps.shape[0]
    fig, axes = plt.subplots(1, num_maps, figsize=figsize)
    # zip pairs each subplot axis with its channel; avoids manual indexing.
    for ax, fmap in zip(axes, maps):
        ax.imshow(fmap.cpu(), cmap='gray')
        ax.axis('off')
    plt.suptitle(title)
    plt.show()

# Visualize the activations captured during the forward pass, using the
# first image of the batch for each conv layer.
_plot_feature_maps(model.feature_maps1[0],
                   'Feature maps after first convolution layer', (15, 5))
_plot_feature_maps(model.feature_maps2[0],
                   'Feature maps after second convolution layer', (20, 5))