
How to Build a CNN in PyTorch: Simple Guide with Example

To build a CNN in PyTorch, define a class inheriting from torch.nn.Module and create convolutional layers using nn.Conv2d, activation functions like nn.ReLU, and pooling layers such as nn.MaxPool2d. Then implement the forward method to specify how data flows through these layers.

Syntax

Here is the basic syntax to define a CNN in PyTorch:

  • Import torch and nn: Use import torch and import torch.nn as nn.
  • Create a class: Inherit from nn.Module to define your CNN.
  • Define layers: Use nn.Conv2d for convolution, nn.ReLU for activation, and nn.MaxPool2d for pooling.
  • Forward method: Define how input data passes through layers.
python
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Assumes 28x28 input: padding=1 keeps 28x28, pooling halves it to 14x14
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # flatten
        x = self.fc(x)
        return x
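
The 16 * 14 * 14 passed to nn.Linear comes from tracking the image size through each layer. As a rough sketch, the helper below (conv2d_out is a hypothetical name, not a PyTorch function) applies the usual output-size formula for convolution and pooling layers, assuming dilation 1 and square inputs:

```python
def conv2d_out(size, kernel_size, stride=1, padding=0):
    """Output size along one spatial dimension (assumes dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1

# 28x28 input through conv1 (kernel 3, padding 1), then the 2x2 max pool
h = conv2d_out(28, kernel_size=3, stride=1, padding=1)
h_pooled = conv2d_out(h, kernel_size=2, stride=2)

# 16 output channels, each h_pooled x h_pooled
flat_features = 16 * h_pooled * h_pooled
print(h, h_pooled, flat_features)  # 28 14 3136
```

If you change the input size, kernel, or padding, rerun the arithmetic and update the in_features of the linear layer to match.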

Example

This example shows a simple CNN for classifying 28x28 grayscale images (like MNIST digits). It has one convolutional layer, ReLU activation, max pooling, and a fully connected layer for output.

python
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Create model instance
model = SimpleCNN()

# Create dummy input: batch size 1, 1 channel, 28x28 image
input_tensor = torch.randn(1, 1, 28, 28)

# Forward pass
output = model(input_tensor)
print(output)
Output (exact values differ on every run because the weights and the input are random)
tensor([[ 0.0951, 0.0347, -0.0345, 0.0423, 0.0427, 0.0272, 0.0319, 0.0123, 0.0227, -0.0027]], grad_fn=<AddmmBackward0>)
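
The ten numbers are raw scores (logits), one per class. A minimal sketch of turning them into a prediction is shown here. It uses an nn.Sequential stand-in with the same layers as SimpleCNN so the snippet runs on its own; that stand-in and the batch size of 4 are assumptions for illustration, and the model is untrained, so the predictions are meaningless.

```python
import torch
import torch.nn as nn

# Compact stand-in with the same layers as SimpleCNN (assumption for brevity)
model = nn.Sequential(
    nn.Conv2d(1, 16, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(16 * 14 * 14, 10),
)
model.eval()

# Dummy batch: 4 grayscale images of 28x28
batch = torch.randn(4, 1, 28, 28)
with torch.no_grad():  # no gradients needed for inference
    logits = model(batch)

probs = torch.softmax(logits, dim=1)  # each row now sums to 1
preds = probs.argmax(dim=1)           # index of the highest-scoring class per image

print(logits.shape)  # torch.Size([4, 10])
print(preds.shape)   # torch.Size([4])
```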

Common Pitfalls

  • Wrong input shape: nn.Conv2d expects input as (batch_size, channels, height, width). Dropping the batch or channel dimension leads to shape errors, either in the convolution itself or later at the fully connected layer.
  • Flattening incorrectly: Use x.view(x.size(0), -1) to flatten before fully connected layers.
  • Missing activation: Without a nonlinearity such as ReLU between layers, stacked convolution and linear layers collapse into a single linear mapping, so the model can't learn complex patterns.
  • Output size mismatch: Make sure the input size to the fully connected layer matches the flattened feature size after convolutions and pooling.
python
import torch
import torch.nn as nn

# Wrong flattening example
class WrongCNN(nn.Module):
    def __init__(self):
        super(WrongCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        # Incorrect flattening - merges the batch dimension into the features
        x = x.view(-1)  # Shape mismatch in self.fc for any batch size > 1
        x = self.fc(x)
        return x

# Correct flattening
class CorrectCNN(nn.Module):
    def __init__(self):
        super(CorrectCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # Correct flattening
        x = self.fc(x)
        return x
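
To see the difference without building the full models, the sketch below applies both flattening styles to a dummy tensor shaped like the pooled feature maps. The batch size of 4 is an assumption; torch.flatten(x, start_dim=1) is used as an equivalent to x.view(x.size(0), -1).

```python
import torch
import torch.nn as nn

fc = nn.Linear(16 * 14 * 14, 10)
features = torch.randn(4, 16, 14, 14)  # pooled feature maps for a batch of 4

# view(-1) merges everything into one vector of 4 * 3136 values
try:
    fc(features.view(-1))
    failed = False
except RuntimeError as err:
    failed = True
    print("view(-1) failed:", type(err).__name__)

# Keeping dimension 0 gives one row of 3136 features per image
out = fc(torch.flatten(features, start_dim=1))
print(out.shape)  # torch.Size([4, 10])

# With a batch of 1, view(-1) does not raise, but the batch dimension is lost
single = fc(torch.randn(1, 16, 14, 14).view(-1))
print(single.shape)  # torch.Size([10])
```

The last case is the reason the wrong version can go unnoticed when you only test with a single image.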

Quick Reference

Remember these key points when building CNNs in PyTorch:

  • Use nn.Conv2d for convolution layers.
  • Apply activation functions like nn.ReLU after convolutions.
  • Use pooling layers like nn.MaxPool2d to reduce spatial size.
  • Flatten with x.view(x.size(0), -1) before fully connected layers.
  • Define the forward method to specify data flow.
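
One way to avoid hard-coding the flattened size is to measure it with a dummy pass inside __init__. This is only a sketch of that idea: the class name InferredCNN and its parameters (in_channels, image_size, num_classes) are made up for this example.

```python
import torch
import torch.nn as nn

class InferredCNN(nn.Module):  # hypothetical name
    def __init__(self, in_channels=1, image_size=28, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # Push a dummy image through the conv block to measure the flattened size
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, image_size, image_size)
            n_features = self.features(dummy).flatten(1).size(1)
        self.fc = nn.Linear(n_features, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.flatten(1)  # keep the batch dimension, flatten the rest
        return self.fc(x)

model = InferredCNN(image_size=32)
print(model.fc.in_features)                     # 4096 (16 channels * 16 * 16)
print(model(torch.randn(2, 1, 32, 32)).shape)   # torch.Size([2, 10])
```

The trade-off is a small extra forward pass at construction time in exchange for not redoing the size arithmetic whenever the layers change.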

Key Takeaways

  • Define CNNs by subclassing torch.nn.Module and implementing the forward method.
  • Use nn.Conv2d, nn.ReLU, and nn.MaxPool2d layers to build convolutional blocks.
  • Always flatten the output correctly before feeding it to fully connected layers.
  • Ensure input tensors have shape (batch_size, channels, height, width).
  • Test your model with dummy data to verify output shapes and flow.
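
As a quick illustration of the last two points, the sketch below adds the missing batch and channel dimensions to a single 28x28 image and prints the shape after every layer. The nn.Sequential model is a stand-in with the same layers as the example above, assumed here so the snippet runs on its own.

```python
import torch
import torch.nn as nn

layers = nn.Sequential(
    nn.Conv2d(1, 16, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(16 * 14 * 14, 10),
)

image = torch.randn(28, 28)          # one grayscale image, no batch or channel dims
x = image.unsqueeze(0).unsqueeze(0)  # -> (1, 1, 28, 28)

shapes = [tuple(x.shape)]
with torch.no_grad():
    for layer in layers:
        x = layer(x)
        shapes.append(tuple(x.shape))
        print(f"{layer.__class__.__name__:<10} -> {tuple(x.shape)}")
```

If a shape in the trace is not what you expect, the layer printed on that line is the place to look.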