Computer Vision · ~5 mins

Vision Transformer (ViT) in Computer Vision

Introduction
Vision Transformer (ViT) helps computers understand images by looking at small parts of the image and learning patterns, just like how we focus on pieces to see the whole picture.
When you want to classify images into categories, like sorting photos of cats and dogs.
When you need to detect objects in pictures, such as finding cars in street photos.
When you want to improve image recognition accuracy using a new method different from traditional tools.
When working with large image datasets and want to use a model that learns relationships between image parts.
When experimenting with modern AI models that use attention to focus on important image details.
Syntax
class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim):
        super().__init__()
        # Split the image into patches
        # Add position information to each patch
        # Stack transformer layers
        # Classify with a final linear layer
        ...

    def forward(self, x):
        # Process the input through patch embedding and the transformer layers
        # Return class predictions
        ...
ViT splits images into small patches and treats them like words in a sentence.
It uses transformer layers to learn how patches relate to each other.
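The patch-splitting idea comes down to simple arithmetic. As a quick sketch (assuming square images and square patches, as in the syntax above), here is how many "words" a ViT sees for a standard 224x224 image with 16x16 patches:

```python
# Sketch: how many patch tokens a ViT produces for a square image.
image_size = 224
patch_size = 16

patches_per_side = image_size // patch_size  # 224 / 16 = 14
num_patches = patches_per_side ** 2          # 14 * 14 = 196

print(num_patches)  # 196 patches, each treated like a token in a sentence
```

So the transformer processes a sequence of 196 patch tokens (plus one class token), rather than raw pixels.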
Examples
Create a Vision Transformer for 224x224 images split into 16x16 patches, with 10 output classes.
vit = VisionTransformer(
    image_size=224,
    patch_size=16,
    num_classes=10,
    dim=512,
    depth=6,
    heads=8,
    mlp_dim=1024
)
Run the model on an input image batch and print the output shape showing class scores.
input_tensor = torch.randn(4, 3, 224, 224)  # batch of 4 RGB 224x224 images
output = vit(input_tensor)
print(output.shape)  # torch.Size([4, 10]): one score per class for each image
Sample Model
This code builds a simple Vision Transformer model. It splits images into patches, adds position info, processes patches with transformer blocks, and outputs class scores. We run it on two random images and print the results.
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # (B, embed_dim, H/patch, W/patch)
        x = x.flatten(2)  # (B, embed_dim, N)
        x = x.transpose(1, 2)  # (B, N, embed_dim)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads)  # expects (seq, batch, dim) by default
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim)
        )

    def forward(self, x):
        x2 = self.norm1(x)
        attn_output, _ = self.attn(x2, x2, x2)
        x = x + attn_output
        x2 = self.norm2(x)
        x = x + self.mlp(x2)
        return x

class VisionTransformer(nn.Module):
    def __init__(self, image_size=224, patch_size=16, num_classes=10, dim=768, depth=6, heads=8, mlp_dim=2048):
        super().__init__()
        self.patch_embed = PatchEmbedding(image_size, patch_size, 3, dim)
        num_patches = (image_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(dim, heads, mlp_dim) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)  # (B, N, dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, N+1, dim)
        x = x + self.pos_embed
        x = x.transpose(0, 1)  # Transformer expects (N+1, B, dim)
        for block in self.transformer_blocks:
            x = block(x)
        x = x.transpose(0, 1)  # (B, N+1, dim)
        x = self.norm(x)
        cls_output = x[:, 0]  # (B, dim)
        out = self.head(cls_output)  # (B, num_classes)
        return out

# Create dummy data: batch of 2 images, 3 channels, 224x224
images = torch.randn(2, 3, 224, 224)

# Create model
model = VisionTransformer()

# Forward pass
outputs = model(images)

# Print output shape and values
print('Output shape:', outputs.shape)
print('Output values:', outputs)
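The model's outputs are unnormalized class scores (logits). A common next step is to turn them into probabilities with softmax and pick the highest-scoring class. The sketch below uses a random stand-in tensor with the same shape as `model(images)` so it runs on its own:

```python
import torch
import torch.nn.functional as F

# Stand-in for the (2, 10) logits produced by model(images) above.
logits = torch.randn(2, 10)

probs = F.softmax(logits, dim=-1)  # convert scores to probabilities
preds = probs.argmax(dim=-1)       # predicted class index per image

print(probs.sum(dim=-1))  # each row sums to 1.0
print(preds.shape)        # torch.Size([2]): one prediction per image
```

In practice you would pass the real logits from the untrained model here; the predictions are only meaningful after training.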
Output
Output shape: torch.Size([2, 10])
(The printed values vary per run because the inputs and weights are random.)
Important Notes
ViT needs images to be split into fixed-size patches before processing.
The class token helps the model summarize the whole image for classification.
Position embeddings tell the model where each patch is in the image.
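The notes above can be verified with a quick shape check. This sketch assumes the sample model's numbers: 196 patches (14x14) with embedding dimension 768 and a batch of 2:

```python
import torch

# (batch, num_patches, dim): 196 patch tokens per image
patch_tokens = torch.zeros(2, 196, 768)

# One learnable class token, copied across the batch
cls_token = torch.zeros(1, 1, 768).expand(2, -1, -1)

# Prepend the class token, then add one position embedding per token slot
tokens = torch.cat((cls_token, patch_tokens), dim=1)  # (2, 197, 768)
pos_embed = torch.zeros(1, 197, 768)
tokens = tokens + pos_embed                           # broadcast over the batch

print(tokens.shape)  # torch.Size([2, 197, 768])
```

The 197th slot is the class token; after the transformer layers, only its output feeds the classification head.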
Summary
Vision Transformer splits images into patches and uses transformer layers to learn patterns.
It uses a special class token to gather information for image classification.
ViT is a modern way to understand images using attention instead of traditional convolution.