PyTorch · How-To · Beginner · 4 min read

How to Do Image Segmentation with PyTorch: Simple Guide

To do image segmentation in PyTorch, you typically use a convolutional neural network such as U-Net or DeepLabV3 that outputs pixel-wise class predictions. You prepare a dataset of images and their segmentation masks, define the model, loss function, and optimizer, then train the model to predict masks from input images.
📐

Syntax

Image segmentation in PyTorch involves these main steps:

  • Dataset: Load images and their corresponding masks.
  • Model: Use a segmentation model like DeepLabV3 from torchvision.models.segmentation.
  • Loss: Use a pixel-wise loss like CrossEntropyLoss.
  • Training: Forward pass input images, compute loss with masks, backpropagate, and update weights.
python
import torch
from torchvision.models.segmentation import deeplabv3_resnet50

# Load pretrained DeepLabV3 model for segmentation
# (weights= replaces the deprecated pretrained=True in torchvision >= 0.13)
model = deeplabv3_resnet50(weights="DEFAULT")
model.eval()  # Set to evaluation mode

# Input tensor shape: (batch_size, 3, height, width)
input_tensor = torch.randn(1, 3, 224, 224)

# Forward pass to get output
output = model(input_tensor)['out']

# Output shape: (batch_size, num_classes, height, width)
print(output.shape)
Output
torch.Size([1, 21, 224, 224])
💻

Example

This example shows how to load a pretrained DeepLabV3 model, run a dummy image through it, and get the predicted segmentation mask.

python
import torch
from torchvision.models.segmentation import deeplabv3_resnet50
from torchvision.transforms import Compose, ToTensor, Normalize
from PIL import Image
import numpy as np

# Load pretrained DeepLabV3 model
# (weights= replaces the deprecated pretrained=True in torchvision >= 0.13)
model = deeplabv3_resnet50(weights="DEFAULT")
model.eval()

# Dummy image creation (random noise image)
image = Image.fromarray((np.random.rand(224,224,3)*255).astype(np.uint8))

# Preprocessing transforms
transform = Compose([
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = transform(image).unsqueeze(0)  # Add batch dimension

# Forward pass
with torch.no_grad():
    output = model(input_tensor)['out']

# Get predicted class for each pixel
pred_mask = output.argmax(1).squeeze().cpu().numpy()

print(f"Predicted mask shape: {pred_mask.shape}")
print(f"Unique classes in mask: {np.unique(pred_mask)}")
Output
Predicted mask shape: (224, 224)
Unique classes in mask: [ 0 15 16 17 18 19 20]
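The 21 output channels correspond to the Pascal VOC classes: torchvision's pretrained segmentation heads were trained on a COCO subset restricted to the 20 VOC categories plus background, so predicted indices can be mapped to human-readable labels. The list below follows the standard VOC ordering:

```python
import numpy as np

# Pascal VOC class order used by torchvision's 21-class segmentation heads
VOC_CLASSES = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
    "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
    "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]

pred_mask = np.array([[0, 15], [15, 8]])  # tiny stand-in for a predicted mask
for idx in np.unique(pred_mask):
    print(f"class {idx}: {VOC_CLASSES[idx]}")
```

For example, class 15 in the output above is "person".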
⚠️

Common Pitfalls

1. Not matching input and mask sizes: The input image and mask must have the same height and width for pixel-wise loss.

2. Using wrong loss function: Use CrossEntropyLoss for multi-class segmentation masks, not MSE.

3. Forgetting to set model to train/eval mode: Use model.train() during training and model.eval() during evaluation to handle layers like dropout and batchnorm correctly.

4. Not normalizing input images: Pretrained models expect inputs normalized with ImageNet mean and std.
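For pitfall 1, one way to fix a size mismatch is to resize the mask to the logits' spatial size with nearest-neighbor interpolation, which keeps class indices intact (bilinear interpolation would blend them into invalid fractional labels). A minimal sketch:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 21, 224, 224)       # model output
mask = torch.randint(0, 21, (1, 300, 400))  # mask at a different size

# interpolate needs a float tensor with a channel dim;
# mode="nearest" keeps the class labels discrete
mask_resized = F.interpolate(
    mask.unsqueeze(1).float(), size=logits.shape[-2:], mode="nearest"
).squeeze(1).long()

print(mask_resized.shape)  # torch.Size([1, 224, 224])
```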

python
import torch
import torch.nn as nn

# Wrong: Using MSELoss for segmentation
loss_fn_wrong = nn.MSELoss()

# Right: Use CrossEntropyLoss for pixel-wise classification
loss_fn_right = nn.CrossEntropyLoss()

print(f"Wrong loss function: {loss_fn_wrong}")
print(f"Right loss function: {loss_fn_right}")
Output
Wrong loss function: MSELoss()
Right loss function: CrossEntropyLoss()
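For segmentation, `CrossEntropyLoss` expects logits of shape `(batch, num_classes, H, W)` and an integer mask of shape `(batch, H, W)` with dtype `torch.long` — no one-hot encoding needed. A minimal shape check:

```python
import torch
import torch.nn as nn

logits = torch.randn(1, 21, 224, 224)       # raw model output
mask = torch.randint(0, 21, (1, 224, 224))  # one class index per pixel (long)

# Reduces over all pixels to a single scalar loss
loss = nn.CrossEntropyLoss()(logits, mask)
print(loss.shape)  # torch.Size([]) -- a scalar
```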
📊

Quick Reference

  • Use torchvision.models.segmentation for pretrained segmentation models.
  • Input images must be normalized with ImageNet stats.
  • Output shape is (batch_size, num_classes, height, width).
  • Use CrossEntropyLoss for training with class masks.
  • Set model to train() or eval() mode appropriately.

Key Takeaways

  • Use pretrained segmentation models like DeepLabV3 from torchvision for easy setup.
  • Normalize input images with ImageNet mean and std before feeding them to the model.
  • Use CrossEntropyLoss for pixel-wise classification during training.
  • Ensure input images and masks have matching spatial dimensions.
  • Switch model modes between train() and eval() to get correct behavior.