PyTorch · ML · ~20 mins

Why CNNs detect spatial patterns in PyTorch - Experiment to Prove It

Experiment - Why CNNs detect spatial patterns
Problem: We want to understand why Convolutional Neural Networks (CNNs) are good at detecting spatial patterns in images. Currently, a simple CNN model is trained on a small image dataset, but it does not clearly show how spatial features are captured.
Current Metrics: Training accuracy: 85%, Validation accuracy: 80%, Loss: 0.45
Issue: The model works, but it is not clear how spatial patterns are detected. The learner needs to see how convolution layers focus on local spatial features.
Your Task
Modify the CNN model to visualize and explain how convolutional layers detect spatial patterns in images. Show intermediate feature maps after convolution layers.
- Use the PyTorch framework
- Keep the model simple (2 convolutional layers)
- Use a small image dataset like CIFAR-10 or a subset
- Do not change the dataset or overall training procedure
Solution (PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Define a simple CNN
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, padding=1)  # 3 input channels, 6 filters
        self.conv2 = nn.Conv2d(6, 12, kernel_size=3, padding=1) # 6 input channels, 12 filters
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(12 * 8 * 8, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))  # After conv1 + relu
        self.feature_maps1 = x.detach()       # Save feature maps after first conv
        x = self.pool(x)  # Pooling after conv1
        x = F.relu(self.conv2(x))  # After conv2 + relu
        self.feature_maps2 = x.detach()       # Save feature maps after second conv
        x = self.pool(x)  # Pooling after conv2
        x = x.view(-1, 12 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Load CIFAR-10 dataset (small subset for speed)
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0)

# Initialize model, loss, optimizer
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train on a single batch (one optimizer step) to keep the demo fast
for images, labels in trainloader:
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    break  # Only one batch for demonstration

# Visualize feature maps from first conv layer
feature_maps = model.feature_maps1[0]  # Take first image in batch
num_maps = feature_maps.shape[0]
fig, axes = plt.subplots(1, num_maps, figsize=(15, 5))
for i in range(num_maps):
    axes[i].imshow(feature_maps[i].cpu(), cmap='gray')
    axes[i].axis('off')
plt.suptitle('Feature maps after first convolution layer')
plt.show()

# Visualize feature maps from second conv layer
feature_maps2 = model.feature_maps2[0]
num_maps2 = feature_maps2.shape[0]
fig2, axes2 = plt.subplots(1, num_maps2, figsize=(20, 5))
for i in range(num_maps2):
    axes2[i].imshow(feature_maps2[i].cpu(), cmap='gray')
    axes2[i].axis('off')
plt.suptitle('Feature maps after second convolution layer')
plt.show()
- Modified the forward method to save feature maps before the pooling layers for clearer spatial pattern visualization
- Kept the model simple: two convolution layers with pooling
- Trained on one batch for a quick demonstration
- Visualized the feature maps as grayscale images to show the spatial patterns detected
Results Interpretation

Before: Model trained but no insight into spatial pattern detection.

After: Visual feature maps show how convolution filters highlight edges, textures, and shapes in local image regions.

Convolutional layers scan small image patches with filters that detect simple spatial features like edges. These features combine in deeper layers to recognize complex patterns. Visualizing feature maps helps understand this spatial detection ability of CNNs.
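To make this concrete, a fixed Sobel-style kernel applied with `F.conv2d` shows a single filter responding to one spatial feature, a vertical edge. The trained filters in `conv1` behave the same way, only with weights found by gradient descent. The synthetic image and kernel below are illustrative, not part of the solution code.

```python
import torch
import torch.nn.functional as F

# A hand-crafted vertical-edge filter (Sobel-like); learned conv filters
# play a similar role, just with trained weights.
kernel = torch.tensor([[-1., 0., 1.],
                       [-2., 0., 2.],
                       [-1., 0., 1.]]).reshape(1, 1, 3, 3)

# Synthetic image: left half dark, right half bright -> one vertical edge
img = torch.zeros(1, 1, 8, 8)
img[:, :, :, 4:] = 1.0

response = F.conv2d(img, kernel, padding=1)

# Interior rows respond strongly (4.0) only at columns 3-4, where the
# edge sits; the nonzero values on the borders are zero-padding artifacts.
print(response[0, 0])
```

Each output value depends only on a 3x3 patch of the input, which is exactly the locality that makes the feature maps in the solution interpretable as spatial detections.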
Bonus Experiment
Try adding more convolutional layers and visualize how feature maps evolve to detect more complex spatial patterns.
💡 Hint
Deeper layers capture more abstract features. Visualize feature maps after each conv layer to see this progression.
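For the bonus experiment, saving feature maps as attributes inside `forward` gets clumsy as depth grows. One scalable alternative is a forward hook on each `Conv2d`, which captures activations without editing the model's forward pass. The three-layer `DeeperCNN` below is a hypothetical extension of the solution model, shown only as a sketch.

```python
import torch
import torch.nn as nn

# Hypothetical deeper variant of the solution model (three conv layers)
class DeeperCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, padding=1)
        self.conv2 = nn.Conv2d(6, 12, 3, padding=1)
        self.conv3 = nn.Conv2d(12, 24, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        return torch.relu(self.conv3(x))

model = DeeperCNN()
captured = {}

def make_hook(name):
    # Forward hooks receive (module, inputs, output); we keep the output
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook

# Register one hook per convolution layer, whatever the depth
for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        module.register_forward_hook(make_hook(name))

_ = model(torch.randn(1, 3, 32, 32))
for name, fmap in captured.items():
    print(name, tuple(fmap.shape))
```

The captured tensors can be plotted exactly like `feature_maps1` in the solution; note how the spatial resolution halves after each pooling stage while the channel count grows.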