
How to Build an Image Classifier with PyTorch

To build an image classifier in PyTorch, define a neural network by subclassing torch.nn.Module, load and preprocess your images with torchvision.datasets and torchvision.transforms, and train the model with a loss function and an optimizer. Finally, evaluate the trained model on held-out test images to check how well its predictions generalize.

Syntax

Here is the basic syntax pattern to build an image classifier in PyTorch:

  • class ModelName(torch.nn.Module): defines the model architecture.
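  • def __init__(self): calls super().__init__() and creates the layers.
  • def forward(self, x): defines how data flows through the model.
  • torchvision.datasets.ImageFolder loads images from a directory with one subfolder per class.
  • torch.utils.data.DataLoader groups samples into (optionally shuffled) batches for training.
  • An optimizer from torch.optim (for example, torch.optim.Adam) updates the model weights from the computed gradients.
  • A loss function (for example, nn.CrossEntropyLoss) measures the error between the model's predictions and the true labels.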
python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Define the model
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1)  # 3 input channels (RGB), 16 filters, 3x3 kernel, stride 1
        self.fc1 = nn.Linear(16*30*30, 10)  # assumes 32x32 inputs: a 3x3 conv without padding gives 30x30 maps

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = x.view(x.size(0), -1)  # flatten
        x = self.fc1(x)
        return x

# Prepare data transforms
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor()
])

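# Load dataset (FakeData generates random images and labels for testing;
# for real images use datasets.ImageFolder("path/to/train", transform=transform))
train_dataset = datasets.FakeData(transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)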

# Initialize model, loss, optimizer
model = SimpleCNN()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

Example

This example is a complete, runnable PyTorch script. It trains the simple CNN on fake image data for a single batch (one optimization step) and prints the training loss.

python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1)
        self.fc1 = nn.Linear(16*30*30, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor()
])

train_dataset = datasets.FakeData(transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

model = SimpleCNN()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

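model.train()  # enable training-mode behavior (matters for layers such as dropout and batch norm)
for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()           # clear gradients left over from the previous step
    output = model(data)            # forward pass: logits of shape (batch_size, 10)
    loss = loss_fn(output, target)  # compare logits with the integer class labels
    loss.backward()                 # backpropagate to compute gradients
    optimizer.step()                # update the weights
    if batch_idx == 0:
        print(f'Training loss: {loss.item():.4f}')
        break  # stop after the first batch to keep the demo short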
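Output
Training loss: 2.3025
The exact number varies from run to run because the images, labels, and initial weights are random. An untrained 10-class classifier typically starts near ln(10) ≈ 2.303.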

Common Pitfalls

Common mistakes when building image classifiers in PyTorch include:

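  • Not flattening the convolutional output before the fully connected layer. nn.Linear operates on the last dimension only, so passing it a 4-D (batch, channels, height, width) tensor raises a shape-mismatch error.
  • Forgetting to call optimizer.zero_grad() before backpropagation. PyTorch accumulates gradients across backward() calls, so each update would include gradients from earlier batches.
  • Using the wrong loss function. For multi-class classification use nn.CrossEntropyLoss, which expects raw logits (no softmax) and integer class labels.
  • Not resizing images to the input size the model expects, or not normalizing them (for example with transforms.Normalize), which causes shape errors or hurts model performance.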
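Two of these pitfalls are easy to observe on tiny tensors. This minimal sketch uses arbitrary example values. It shows that calling backward() twice without zeroing adds the gradients together. It also shows that nn.CrossEntropyLoss on raw logits equals log_softmax followed by nll_loss, which is why a softmax should not be applied before it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# 1) Gradients accumulate: backward() adds to .grad instead of overwriting it.
w = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (w ** 2).sum()          # d(loss)/dw = 2 * w = [2., 4.]
loss.backward()
first_grad = w.grad.clone()
loss = (w ** 2).sum()          # rebuild the graph for a second backward pass
loss.backward()                # no zeroing in between, so the gradients are summed
accumulated_grad = w.grad.clone()
print(first_grad, accumulated_grad)  # tensor([2., 4.]) tensor([4., 8.])
w.grad = None                  # reset; optimizer.zero_grad() does this for each parameter it manages

# 2) CrossEntropyLoss takes raw logits and integer class indices, not probabilities.
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])
ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, manual))  # True
```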
python
import torch
import torch.nn as nn

# Wrong: missing flatten before fc layer
class WrongModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3)
        self.fc = nn.Linear(16*30*30, 10)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        # Missing flatten: nn.Linear sees a last dimension of 30, not 16*30*30, and raises a RuntimeError
        x = self.fc(x)
        return x

# Right: flatten before fc
class RightModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3)
        self.fc = nn.Linear(16*30*30, 10)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = x.view(x.size(0), -1)  # flatten
        x = self.fc(x)
        return x
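To see the flatten pitfall concretely, and to avoid working out 16*30*30 by hand, you can pass a dummy tensor through the convolution and inspect the shapes. This sketch assumes 32x32 RGB inputs, as in the models above. It uses torch.flatten(x, 1), which here gives the same result as x.view(x.size(0), -1).

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 16, 3)          # kernel 3, stride 1, no padding
x = torch.randn(8, 3, 32, 32)       # dummy batch of 8 RGB images, 32x32 (assumed input size)
features = torch.relu(conv(x))
print(features.shape)               # torch.Size([8, 16, 30, 30])

# nn.Linear acts on the last dimension only, so the 4-D tensor fails.
fc = nn.Linear(16 * 30 * 30, 10)
try:
    fc(features)
except RuntimeError as err:
    print("shape error:", type(err).__name__)  # shape error: RuntimeError

# Flattening everything except the batch dimension fixes it.
flat = torch.flatten(features, 1)
print(flat.shape)                   # torch.Size([8, 14400])
out = fc(flat)
print(out.shape)                    # torch.Size([8, 10])

# A dummy forward pass can measure the flattened size instead of computing it by hand.
with torch.no_grad():
    n_features = torch.flatten(conv(torch.zeros(1, 3, 32, 32)), 1).shape[1]
print(n_features)                   # 14400
```

Another option is torch.nn.LazyLinear(10), which infers in_features from the first input it receives.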

Quick Reference

Summary tips for building image classifiers in PyTorch:

  • Use torchvision.transforms to resize and normalize images.
  • Define model by subclassing torch.nn.Module and implementing forward.
  • Use DataLoader for batching and shuffling data.
  • Choose CrossEntropyLoss for multi-class classification.
  • Call optimizer.zero_grad() before loss.backward() to reset gradients.

Key Takeaways

Define your model by subclassing torch.nn.Module and implementing the forward method.
Use torchvision transforms to prepare images with resizing and normalization.
Train with DataLoader batches, CrossEntropyLoss, and an optimizer like Adam.
Always zero gradients before backpropagation to avoid gradient accumulation.
Flatten convolution outputs before feeding into fully connected layers.