How to Build a CNN in PyTorch: Simple Guide with Example
To build a CNN in PyTorch, define a class inheriting from
torch.nn.Module and create convolutional layers using nn.Conv2d, activation functions like nn.ReLU, and pooling layers such as nn.MaxPool2d. Then implement the forward method to specify how data flows through these layers.

Syntax
Here is the basic syntax to define a CNN in PyTorch:
- Import torch and nn: Use `import torch` and `import torch.nn as nn`.
- Create a class: Inherit from `nn.Module` to define your CNN.
- Define layers: Use `nn.Conv2d` for convolution, `nn.ReLU` for activation, and `nn.MaxPool2d` for pooling.
- Forward method: Define how input data passes through layers.
```python
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(16 * 14 * 14, 10)  # assuming input 28x28

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # flatten
        x = self.fc(x)
        return x
```
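To see why the fully connected layer expects 16 * 14 * 14 input features, you can trace the tensor shapes step by step (a quick check, assuming a 28x28 grayscale input as in the syntax above):

```python
import torch
import torch.nn as nn

conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
pool = nn.MaxPool2d(kernel_size=2, stride=2)

x = torch.randn(1, 1, 28, 28)       # batch of one 28x28 grayscale image
x = conv1(x)                        # padding=1 preserves spatial size: (1, 16, 28, 28)
print(x.shape)
x = pool(x)                         # 2x2 pooling halves height and width: (1, 16, 14, 14)
print(x.shape)
print(x.view(x.size(0), -1).shape)  # flattened: (1, 3136), i.e. 16 * 14 * 14
```

Running this kind of shape trace before writing the linear layer avoids the output-size mismatch described in the pitfalls below.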
Example
This example shows a simple CNN for classifying 28x28 grayscale images (like MNIST digits). It has one convolutional layer, ReLU activation, max pooling, and a fully connected layer for output.
```python
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Create model instance
model = SimpleCNN()

# Create dummy input: batch size 1, 1 channel, 28x28 image
input_tensor = torch.randn(1, 1, 28, 28)

# Forward pass
output = model(input_tensor)
print(output)
```
Output
```
tensor([[ 0.0951,  0.0347, -0.0345,  0.0423,  0.0427,  0.0272,  0.0319,  0.0123,
          0.0227, -0.0027]], grad_fn=<AddmmBackward0>)
```
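The ten values above are raw logits, one per class. To interpret them as class probabilities, apply softmax (a quick illustration using the logits printed above):

```python
import torch
import torch.nn.functional as F

# Raw logits from the example output above
logits = torch.tensor([[0.0951, 0.0347, -0.0345, 0.0423, 0.0427,
                        0.0272, 0.0319, 0.0123, 0.0227, -0.0027]])

probs = F.softmax(logits, dim=1)          # probabilities along the class dimension
predicted_class = probs.argmax(dim=1)     # index of the most likely class
print(probs.sum())                        # probabilities sum to 1
print(predicted_class)                    # highest logit (0.0951) is at index 0
```

An untrained model produces near-uniform probabilities like these; only after training do the logits become meaningful.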
Common Pitfalls
- Wrong input shape: CNN expects input as (batch_size, channels, height, width). Forgetting channels or batch dimension causes errors.
- Flattening incorrectly: Use `x.view(x.size(0), -1)` to flatten before fully connected layers.
- Missing activation: Without activation functions like ReLU, the model can't learn complex patterns.
- Output size mismatch: Make sure the input size to the fully connected layer matches the flattened feature size after convolutions and pooling.
```python
import torch
import torch.nn as nn

# Wrong flattening example
class WrongCNN(nn.Module):
    def __init__(self):
        super(WrongCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        # Incorrect flattening - collapses the batch dimension too,
        # so this raises a shape-mismatch error for any batch size > 1
        x = x.view(-1)
        x = self.fc(x)
        return x

# Correct flattening
class CorrectCNN(nn.Module):
    def __init__(self):
        super(CorrectCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # Correct flattening: keep the batch dimension
        x = self.fc(x)
        return x
```
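The input-shape pitfall from the list above can be fixed with unsqueeze, which adds the missing dimensions (a minimal sketch, assuming a single 28x28 grayscale image loaded without batch or channel dimensions):

```python
import torch

image = torch.randn(28, 28)          # a bare 28x28 image: no channel or batch dimension

# Passing `image` to a Conv2d model would fail, because Conv2d expects
# input shaped (batch_size, channels, height, width).
x = image.unsqueeze(0)               # add channel dimension -> (1, 28, 28)
x = x.unsqueeze(0)                   # add batch dimension   -> (1, 1, 28, 28)
print(x.shape)                       # torch.Size([1, 1, 28, 28])
```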
Quick Reference
Remember these key points when building CNNs in PyTorch:
- Use `nn.Conv2d` for convolution layers.
- Apply activation functions like `nn.ReLU` after convolutions.
- Use pooling layers like `nn.MaxPool2d` to reduce spatial size.
- Flatten with `x.view(x.size(0), -1)` before fully connected layers.
- Define the `forward` method to specify data flow.
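For simple stacks like the one in this guide, the same model can also be expressed with nn.Sequential, which removes the need for a handwritten forward method (a compact alternative to the class-based version, not a replacement for it in general):

```python
import torch
import torch.nn as nn

# Equivalent model built with nn.Sequential;
# nn.Flatten() plays the role of x.view(x.size(0), -1)
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 14 * 14, 10),
)

output = model(torch.randn(1, 1, 28, 28))
print(output.shape)  # torch.Size([1, 10])
```

Subclassing nn.Module remains the better choice once the forward pass needs branching, skip connections, or multiple inputs.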
Key Takeaways
- Define CNNs by subclassing torch.nn.Module and implementing the forward method.
- Use nn.Conv2d, nn.ReLU, and nn.MaxPool2d layers to build convolutional blocks.
- Always flatten the output correctly before feeding it to fully connected layers.
- Ensure input tensors have shape (batch_size, channels, height, width).
- Test your model with dummy data to verify output shapes and flow.