
CNN architecture for image classification in PyTorch

Introduction

A convolutional neural network (CNN) learns to recognize images by sliding small filters over local patches of the picture and combining what they detect, layer by layer.

When you want a computer to tell if a photo has a cat or a dog.
When sorting pictures into groups like cars, trees, or people.
When you want to find objects in photos, like faces or signs.
When you want to improve photo search by recognizing what's inside.
When building apps that need to understand images, like photo filters.
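To make "looking at small parts" concrete, here is a minimal sketch of one filter sliding over a tiny image. The image and the filter values are made up for illustration; in a real CNN the filter weights are learned during training.

```python
import torch
import torch.nn.functional as F

# A 1-channel 4x5 "image": dark (0) on the left, bright (1) on the right.
image = torch.tensor([[0., 0., 0., 1., 1.],
                      [0., 0., 0., 1., 1.],
                      [0., 0., 0., 1., 1.],
                      [0., 0., 0., 1., 1.]]).reshape(1, 1, 4, 5)

# A hand-written 3x3 filter that reacts to dark-to-bright changes from left to right.
kernel = torch.tensor([[-1., 0., 1.],
                       [-1., 0., 1.],
                       [-1., 0., 1.]]).reshape(1, 1, 3, 3)

# Slide the filter over every 3x3 patch (no padding, so 4x5 becomes 2x3).
response = F.conv2d(image, kernel)
print(response.squeeze().tolist())  # [[0.0, 3.0, 3.0], [0.0, 3.0, 3.0]]
```

The flat dark patch gives 0, while patches that contain the edge give 3, so the filter output marks where the pattern appears.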
Syntax
PyTorch
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 10)  # 32 channels x 8 x 8 after two 2x2 pools on a 32x32 input

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # flatten
        x = self.fc1(x)
        return x

Input images are expected to have 3 color channels (RGB) and size 32x32 pixels.

Output layer size (10) matches the number of classes to predict.
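One way to check where the `32 * 8 * 8` comes from is to push a dummy batch through the same layers and record the shape after each step. This is a sketch: the batch size of 4 and the random input are arbitrary choices, not part of the original model.

```python
import torch
import torch.nn as nn

# Same layers as SimpleCNN, listed in the order forward() applies them.
layers = [
    ("conv1", nn.Conv2d(3, 16, kernel_size=3, padding=1)),
    ("relu", nn.ReLU()),
    ("pool", nn.MaxPool2d(kernel_size=2, stride=2)),
    ("conv2", nn.Conv2d(16, 32, kernel_size=3, padding=1)),
    ("relu", nn.ReLU()),
    ("pool", nn.MaxPool2d(kernel_size=2, stride=2)),
]

x = torch.randn(4, 3, 32, 32)  # dummy batch: 4 RGB images of 32x32 pixels
shapes = [("input", tuple(x.shape))]
for name, layer in layers:
    x = layer(x)
    shapes.append((name, tuple(x.shape)))

for name, shape in shapes:
    print(f"{name:>6}: {shape}")

flat = x.view(x.size(0), -1)
print(tuple(flat.shape))  # (4, 2048), and 2048 == 32 * 8 * 8
```

Each pooling step halves height and width (32 to 16 to 8), while the convolutions with `padding=1` keep them unchanged.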

Examples
The first convolution layer applies 16 filters of size 3x3 (padding of 1 keeps height and width unchanged), followed by 2x2 max pooling, which halves height and width.
PyTorch
self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
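The size after a convolution or pooling layer can also be worked out by hand. The helper below is a hypothetical name introduced for this sketch; it uses the standard output-size formula for `Conv2d` and `MaxPool2d`, assuming dilation of 1 and the default floor rounding.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Output height or width for a conv or pooling layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel_size) // stride + 1

# Conv2d(3, 16, 3, padding=1) on a 32-pixel side: size is unchanged.
after_conv = conv_output_size(32, kernel_size=3, padding=1)
# MaxPool2d(2, 2): size is halved.
after_pool = conv_output_size(after_conv, kernel_size=2, stride=2)

print(after_conv, after_pool)  # 32 16
```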
Flatten each image's 3D feature maps (channels x height x width) into a 1D vector before feeding them into the fully connected layer. The batch dimension is kept.
PyTorch
x = x.view(x.size(0), -1)
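As a small sketch, the same flattening can be written in a few equivalent ways. The tensor here is random stand-in data with the shape the second pooling layer would produce for a batch of 4.

```python
import torch

x = torch.randn(4, 32, 8, 8)  # stand-in for feature maps: (batch, channels, height, width)

a = x.view(x.size(0), -1)          # the form used in this tutorial
b = torch.flatten(x, start_dim=1)  # flattens everything except the batch dimension
c = x.reshape(x.size(0), -1)       # like view, but copies if the memory layout requires it

print(tuple(a.shape))                        # (4, 2048)
print(torch.equal(a, b), torch.equal(a, c))  # True True
```

`view` needs a compatible memory layout, so `reshape` or `torch.flatten` can be the safer choice after operations such as `permute`.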
Fully connected layer that outputs 10 class scores from the flattened features.
PyTorch
self.fc1 = nn.Linear(32 * 8 * 8, 10)
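The 10 outputs are raw class scores (logits), not probabilities. The sketch below, using random stand-in features, shows how softmax turns them into probabilities without changing which class scores highest. Note that `nn.CrossEntropyLoss` expects the raw scores, so the model itself does not need a softmax layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc1 = nn.Linear(32 * 8 * 8, 10)
features = torch.randn(4, 32 * 8 * 8)  # stand-in for flattened feature maps

logits = fc1(features)                # raw class scores, shape (4, 10)
probs = torch.softmax(logits, dim=1)  # each row now sums to 1

print(tuple(logits.shape))
print(probs.sum(dim=1))
# Softmax keeps the ordering of scores, so the predicted class is the same.
same_prediction = torch.equal(logits.argmax(dim=1), probs.argmax(dim=1))
print(same_prediction)
```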
Sample Model

This code trains the CNN on one batch of CIFAR10 images and prints the loss and predicted classes for the first 5 images.

PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

# Prepare data (CIFAR10, small image dataset)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Initialize model, loss, optimizer
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Run one training step on a single batch (demo only, not a full epoch)
model.train()
for images, labels in trainloader:
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    break  # train on only 1 batch for demo

# Print loss and prediction example
print(f"Loss after 1 batch: {loss.item():.4f}")
_, predicted = torch.max(outputs, 1)
print(f"Predicted classes for first 5 images: {predicted[:5].tolist()}")
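After training, a model is usually checked on data it has not seen. The sketch below shows one common pattern. The `accuracy` helper, the tiny stand-in model, and the random data are all assumptions made so the example runs without downloading CIFAR10.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def accuracy(model, loader):
    """Fraction of samples whose highest-scoring class matches the label."""
    model.eval()                # evaluation mode (matters for dropout, batch norm)
    correct, total = 0, 0
    with torch.no_grad():       # gradients are not needed for evaluation
        for images, labels in loader:
            predicted = model(images).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Stand-in model and random data (hypothetical, for illustration only).
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
toy_data = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))
acc = accuracy(toy_model, DataLoader(toy_data, batch_size=32))
print(f"Accuracy on random data: {acc:.2f}")  # likely close to chance, since the labels are random
```

With the sample model, the same helper could be called with `model` and a `DataLoader` over `datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)`.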
Important Notes

When starting out, a moderate batch size such as the 64 used above keeps each training step quick and memory use low.

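ReLU adds non-linearity. Without it, stacked convolution and linear layers would collapse into a single linear transformation, no matter how many layers are used.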
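A minimal sketch of why the non-linearity matters, using two small linear layers (sizes chosen arbitrarily): without an activation between them they behave exactly like one linear layer, and with ReLU they do not. Convolutions are linear operations too, so the same reasoning applies to them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin1 = nn.Linear(8, 16, bias=False)
lin2 = nn.Linear(16, 4, bias=False)
x = torch.randn(5, 8)

# Two linear layers in a row equal one layer whose weight is the product of both weights.
stacked = lin2(lin1(x))
merged = x @ (lin2.weight @ lin1.weight).T
print(torch.allclose(stacked, merged, atol=1e-5))    # True

# With ReLU in between, the result can no longer be written as that single matrix.
with_relu = lin2(torch.relu(lin1(x)))
print(torch.allclose(with_relu, merged, atol=1e-5))  # False
```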

Max pooling with kernel size 2 and stride 2 halves the height and width of each feature map, keeping only the strongest response in each 2x2 window.
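Here is a small sketch of 2x2 max pooling on a hand-made 4x4 feature map. The values are invented so the result is easy to check by eye.

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2)

# One feature map of size 4x4, shaped (batch, channels, height, width).
x = torch.tensor([[1., 2., 5., 6.],
                  [3., 4., 7., 8.],
                  [9., 1., 2., 3.],
                  [0., 5., 4., 1.]]).reshape(1, 1, 4, 4)

y = pool(x)  # keeps the largest value in each non-overlapping 2x2 window
print(tuple(y.shape))        # (1, 1, 2, 2)
print(y.squeeze().tolist())  # [[4.0, 8.0], [9.0, 4.0]]
```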

Summary

CNNs look at images in small parts to learn patterns.

Convolution layers find features, pooling layers shrink images, and fully connected layers decide the class.

Training adjusts the CNN's weights to reduce the loss, so its predictions match the labels more often.
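To see training adjust the weights, the sketch below repeats the training step several times on one fixed batch and records the loss. It is an illustration under assumptions: the batch is random stand-in data rather than CIFAR10, the model is a compact `nn.Sequential` with the same layer pattern as `SimpleCNN`, and 20 steps is an arbitrary choice.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Same layer pattern as SimpleCNN, written compactly.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(32 * 8 * 8, 10),
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Random stand-in batch (an assumption, used to avoid a dataset download).
images = torch.randn(16, 3, 32, 32)
labels = torch.randint(0, 10, (16,))

losses = []
model.train()
for step in range(20):
    optimizer.zero_grad()                    # clear gradients from the previous step
    loss = criterion(model(images), labels)  # how wrong the current predictions are
    loss.backward()                          # compute gradients
    optimizer.step()                         # update the weights
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

The loss on this batch should fall as the steps repeat. Fitting one batch this way only shows that the loop works; real training loops over the whole dataset for several epochs.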