
Vision Transformer (ViT) in Computer Vision

Introduction
Vision Transformer (ViT) is an image model that splits a picture into small, fixed-size patches and uses attention to learn how those patches relate to each other, much as we focus on individual pieces of a scene to understand the whole picture.
Typical reasons to use a ViT:
When you want to classify images into categories, such as sorting photos of cats and dogs.
When you need a backbone for object detection, such as finding cars in street photos (a detection head is added on top of the ViT features).
When you want an alternative to traditional convolutional neural networks (CNNs) for image recognition; when pretrained on enough data, ViTs can match or exceed CNN accuracy.
When you are working with large image datasets, which give the model enough examples to learn relationships between image parts.
When you are experimenting with modern AI models that use attention to focus on important image details.
Syntax
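import torch.nn as nn

class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim):
        super().__init__()
        # Split the image into patches and project each patch to a dim-sized embedding
        # Add a learnable class token and position embeddings
        # Stack `depth` transformer layers, each with `heads` attention heads and an MLP of width mlp_dim
        # Classify with a final linear layer over num_classes
        ...  # placeholder: see the Sample Model below for a full implementation

    def forward(self, x):
        # x: (batch, 3, image_size, image_size)
        # Process input through the patch embedding and transformer layers
        # Return class predictions of shape (batch, num_classes)
        ...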
ViT splits an image into small, non-overlapping patches and treats each one like a word (token) in a sentence.
It uses transformer layers with self-attention to learn how the patches relate to each other.
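The patch idea can be sketched with plain tensor reshapes. The helper `image_to_patches` below is hypothetical (not a PyTorch function): it cuts a batch of images into non-overlapping 16x16 patches and flattens each one into a vector, so a 224x224 RGB image becomes a sequence of 196 "words", each of length 16 x 16 x 3 = 768. It assumes the height and width are divisible by the patch size.

```python
import torch

def image_to_patches(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split (B, C, H, W) images into (B, N, C * patch_size * patch_size) patch tokens.

    Hypothetical helper for illustration; assumes H and W are divisible by patch_size.
    """
    B, C, H, W = images.shape
    assert H % patch_size == 0 and W % patch_size == 0, "image size must be divisible by patch size"
    p = patch_size
    # (B, C, H/p, p, W/p, p): carve height and width into a grid of p x p blocks
    x = images.reshape(B, C, H // p, p, W // p, p)
    # (B, H/p, W/p, C, p, p): move the grid coordinates in front of each patch's pixels
    x = x.permute(0, 2, 4, 1, 3, 5)
    # (B, N, C*p*p): one flattened vector ("word") per patch, in row-major grid order
    return x.reshape(B, (H // p) * (W // p), C * p * p)

images = torch.randn(2, 3, 224, 224)
tokens = image_to_patches(images, patch_size=16)
print(tokens.shape)  # torch.Size([2, 196, 768])
```

A real ViT then maps each flattened patch to the model dimension with a learned linear layer; the Sample Model below does the splitting and the projection in one step with a strided convolution.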
Examples
Create a Vision Transformer for 224x224 images split into 16x16 patches, with 10 output classes.
vit = VisionTransformer(
    image_size=224,
    patch_size=16,
    num_classes=10,
    dim=512,
    depth=6,
    heads=8,
    mlp_dim=1024
)
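As a quick sanity check on these settings, the sketch below works out the length of the sequence the transformer actually processes: each side of the image holds 224 / 16 = 14 patches, giving 14 x 14 = 196 patch tokens, plus one class token. The function name `vit_sequence_length` is hypothetical, used only for this illustration, and it assumes square images.

```python
def vit_sequence_length(image_size: int, patch_size: int) -> int:
    """Number of tokens a ViT encoder processes: one per patch plus one class token.

    Hypothetical helper; assumes square images whose side is divisible by patch_size.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1

print(vit_sequence_length(224, 16))  # 197
print(vit_sequence_length(224, 32))  # 50
```

Doubling the patch size cuts the number of patch tokens by a factor of four, which is one of the main knobs for trading detail against compute.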
Run the model on an input image batch and print the output shape showing class scores.
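import torch

input_tensor = torch.randn(4, 3, 224, 224)  # a batch of 4 RGB images, 224x224 pixels
output = vit(input_tensor)
print(output.shape)  # torch.Size([4, 10]): one score per class for each image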
Sample Model
This code builds a simple Vision Transformer model. It splits images into patches, adds position info, processes patches with transformer blocks, and outputs class scores. We run it on two random images and print the results.
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # (B, embed_dim, H/patch, W/patch)
        x = x.flatten(2)  # (B, embed_dim, N)
        x = x.transpose(1, 2)  # (B, N, embed_dim)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim)
        )

    def forward(self, x):
        x2 = self.norm1(x)
        attn_output, _ = self.attn(x2, x2, x2)
        x = x + attn_output
        x2 = self.norm2(x)
        x = x + self.mlp(x2)
        return x

class VisionTransformer(nn.Module):
    def __init__(self, image_size=224, patch_size=16, num_classes=10, dim=768, depth=6, heads=8, mlp_dim=2048):
        super().__init__()
        self.patch_embed = PatchEmbedding(image_size, patch_size, 3, dim)
        num_patches = (image_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(dim, heads, mlp_dim) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)  # (B, N, dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, N+1, dim)
        x = x + self.pos_embed
        x = x.transpose(0, 1)  # nn.MultiheadAttention defaults to batch_first=False, so it expects (N+1, B, dim)
        for block in self.transformer_blocks:
            x = block(x)
        x = x.transpose(0, 1)  # (B, N+1, dim)
        x = self.norm(x)
        cls_output = x[:, 0]  # (B, dim)
        out = self.head(cls_output)  # (B, num_classes)
        return out

# Create dummy data: batch of 2 images, 3 channels, 224x224
images = torch.randn(2, 3, 224, 224)

# Create model
model = VisionTransformer()

# Forward pass
outputs = model(images)

# Print output shape and values
print('Output shape:', outputs.shape)
print('Output values:', outputs)
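The `PatchEmbedding` module above uses `nn.Conv2d` with `kernel_size` and `stride` both equal to `patch_size`. Because the windows never overlap, this should be the same as flattening each patch and applying one shared linear layer. The check below is a sketch of that claim: it copies the convolution's weights into an `nn.Linear` and compares the two paths on random input (the sizes mirror the sample model's defaults).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
patch_size, in_chans, embed_dim = 16, 3, 768

conv = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
linear = nn.Linear(in_chans * patch_size * patch_size, embed_dim)

# Reuse the conv filters as the linear layer's weights: each (C, p, p) filter becomes one row
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(embed_dim, -1))
    linear.bias.copy_(conv.bias)

images = torch.randn(2, in_chans, 224, 224)

# Path 1: strided convolution, then flatten the 14x14 grid into a sequence
conv_tokens = conv(images).flatten(2).transpose(1, 2)

# Path 2: cut non-overlapping patches, flatten each to C*p*p values, apply the linear layer
B, C, H, W = images.shape
p = patch_size
patches = images.reshape(B, C, H // p, p, W // p, p).permute(0, 2, 4, 1, 3, 5)
patches = patches.reshape(B, (H // p) * (W // p), C * p * p)
linear_tokens = linear(patches)

print(conv_tokens.shape, linear_tokens.shape)  # torch.Size([2, 196, 768]) torch.Size([2, 196, 768])
print(torch.allclose(conv_tokens, linear_tokens, atol=1e-4))  # the two paths agree up to float rounding
```

This is why the ViT paper describes the patch embedding as a trainable linear projection of flattened patches; the strided convolution is simply an efficient way to compute it.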
Important Notes
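ViT needs images to be split into fixed-size patches before processing. In the sample model, inputs must have the same size as the image_size the model was built with, because the position embeddings have exactly one entry per patch (plus one for the class token).
The class token is a learnable embedding prepended to the patch sequence; after the transformer layers, its output summarizes the whole image for classification.
Self-attention on its own ignores token order, so position embeddings tell the model where each patch is in the image.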
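To see why position embeddings are needed, the sketch below feeds random tokens through `nn.MultiheadAttention` twice: once in the original order and once shuffled. Without position information, shuffling the inputs only shuffles the outputs the same way, so the model cannot tell where a patch came from. The sizes, the fixed permutation, and the random stand-in for learned position embeddings are all assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, dim, heads = 5, 2, 32, 4

attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads)  # expects (seq, batch, dim)
attn.eval()

tokens = torch.randn(seq_len, batch, dim)
perm = torch.tensor([3, 0, 4, 1, 2])  # a fixed reordering of the 5 tokens

with torch.no_grad():
    out, _ = attn(tokens, tokens, tokens)
    shuffled = tokens[perm]
    out_shuffled, _ = attn(shuffled, shuffled, shuffled)

# Attention alone is order-blind: the shuffled run equals the original outputs, shuffled
print(torch.allclose(out_shuffled, out[perm], atol=1e-5))

# A distinct embedding per position (random here, learned in a real ViT) breaks the symmetry
pos_embed = torch.randn(seq_len, 1, dim)
with torch.no_grad():
    x = tokens + pos_embed
    out_pos, _ = attn(x, x, x)
    x_shuffled = tokens[perm] + pos_embed
    out_pos_shuffled, _ = attn(x_shuffled, x_shuffled, x_shuffled)
print(torch.allclose(out_pos_shuffled, out_pos[perm], atol=1e-5))
```

The first check should hold and the second should not, which is why the sample model adds `self.pos_embed` to the token sequence before the first transformer block.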
Summary
Vision Transformer splits images into patches and uses transformer layers to learn patterns.
It uses a special class token to gather information for image classification.
ViT processes images with self-attention over patch tokens instead of the stacked convolutions used by traditional CNNs.

Practice

1. What is the main purpose of splitting an image into patches in a Vision Transformer (ViT)?
easy
A. To reduce the image size by cropping
B. To convert the image into smaller parts that the transformer can process as tokens
C. To apply convolution filters on each patch separately
D. To increase the image resolution for better detail

Solution

  1. Step 1: Understand ViT input processing

    ViT splits images into fixed-size patches to treat each patch like a word token in language models.
  2. Step 2: Purpose of patch splitting

    This allows the transformer to process image patches as a sequence, enabling attention mechanisms to learn relationships.
  3. Final Answer:

    To convert the image into smaller parts that the transformer can process as tokens -> Option B
  4. Quick Check:

    Image patches = tokens for transformer
Hint: Think of patches as words in a sentence for the transformer
Common Mistakes:
  • Confusing patch splitting with image resizing
  • Thinking patches are processed by convolution
  • Assuming patches increase image resolution
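2. Which of the following is the correct way to add a class token to the patch embeddings in ViT using PyTorch? Assume tensors are shaped (batch, sequence, embedding) and class_token has already been expanded to the batch size.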
easy
A. patches = torch.cat([class_token, patches], dim=1)
B. patches = torch.cat([patches, class_token], dim=1)
C. patches = torch.cat([patches, class_token], dim=0)
D. patches = torch.cat([class_token, patches], dim=0)

Solution

  1. Step 1: Understand tensor concatenation dimension

    Patch embeddings are sequences along dimension 1 (batch, seq, embed); class token must be prepended along this dimension.
  2. Step 2: Correct concatenation syntax

    Using torch.cat with dim=1 adds class_token at the start of the sequence correctly.
  3. Final Answer:

    patches = torch.cat([class_token, patches], dim=1) -> Option A
  4. Quick Check:

    Class token prepended along the sequence dimension (dim=1)
Hint: Class token goes first; concatenate along the sequence dimension (dim=1)
Common Mistakes:
  • Concatenating along wrong dimension (dim=0)
  • Appending class token at the end instead of start
  • Mixing order of tensors in concat
3. Given the following simplified ViT patch embedding code, what is the shape of patch_embeddings after processing a batch of 8 images of size 32x32 with patch size 8 and embedding dimension 64?
import torch

patch_size = 8
embedding_dim = 64
batch_size = 8
image_size = 32
num_patches = (image_size // patch_size) ** 2
patch_embeddings = torch.randn(batch_size, num_patches, embedding_dim)
medium
A. (16, 8, 64)
B. (8, 64, 16)
C. (8, 8, 64)
D. (8, 16, 64)

Solution

  1. Step 1: Calculate number of patches

    Number of patches = (32 / 8)^2 = 4^2 = 16 patches per image.
  2. Step 2: Determine patch_embeddings shape

    Shape is (batch_size, num_patches, embedding_dim) = (8, 16, 64).
  3. Final Answer:

    (8, 16, 64) -> Option D
  4. Quick Check:

    Batch=8, patches=16, embed=64
Hint: Calculate patches as (image/patch)^2; the shape is then batch x patches x embed
Common Mistakes:
  • Mixing embedding dimension and patch count order
  • Calculating patches incorrectly
  • Confusing batch size with patch count
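You can confirm the (8, 16, 64) shape by running a real patch-embedding layer instead of `torch.randn`. This sketch uses a `Conv2d` projection like the `PatchEmbedding` class in the sample model and assumes 3-channel (RGB) input, which the question does not specify.

```python
import torch
import torch.nn as nn

batch_size, image_size, patch_size, embedding_dim = 8, 32, 8, 64

# Non-overlapping patch projection: one 8x8 window per output position
proj = nn.Conv2d(3, embedding_dim, kernel_size=patch_size, stride=patch_size)

images = torch.randn(batch_size, 3, image_size, image_size)
grid = proj(images)                                 # a 4x4 grid of embedded patches
patch_embeddings = grid.flatten(2).transpose(1, 2)  # (batch, patches, embedding)

print(grid.shape)              # torch.Size([8, 64, 4, 4])
print(patch_embeddings.shape)  # torch.Size([8, 16, 64])
```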
4. You have this ViT code snippet that throws an error:
import torch
class_token = torch.randn(1, 1, 64)
patches = torch.randn(8, 16, 64)
input_seq = torch.cat([class_token, patches], dim=1)

What is the cause of the error?
medium
A. Embedding dimensions do not match
B. Wrong concatenation dimension; should be dim=0
C. class_token shape should be (8, 1, 64) to match batch size
D. Dimension mismatch because class_token sequence size is 1 but patches sequence size is 16

Solution

  1. Step 1: Check batch size compatibility

    class_token has batch size 1, patches have batch size 8; they must match for concatenation.
  2. Step 2: Fix class_token shape

    class_token should be expanded (for example with class_token.expand(8, -1, -1)), repeated, or created with shape (8, 1, 64) to match the batch size of patches.
  3. Final Answer:

    class_token shape should be (8, 1, 64) to match batch size -> Option C
  4. Quick Check:

    Batch sizes must match for concat
Hint: Match batch sizes before concatenating tensors
Common Mistakes:
  • Ignoring batch size mismatch
  • Changing wrong concat dimension
  • Assuming embedding dims cause error
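A minimal corrected version of the snippet keeps a single class token of shape (1, 1, 64) and broadcasts it to the batch with `expand` before concatenating, which mirrors what the sample model's `forward` does. `expand` returns a view without copying memory; `repeat(8, 1, 1)` would also work but allocates a copy.

```python
import torch

class_token = torch.randn(1, 1, 64)  # one shared class token
patches = torch.randn(8, 16, 64)     # (batch, num_patches, dim)

# Broadcast the class token to every image in the batch: (1, 1, 64) -> (8, 1, 64)
batch_class_tokens = class_token.expand(patches.shape[0], -1, -1)

# Prepend along the sequence dimension (dim=1): (8, 1 + 16, 64)
input_seq = torch.cat([batch_class_tokens, patches], dim=1)
print(input_seq.shape)  # torch.Size([8, 17, 64])
```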
5. In a Vision Transformer model, why is the class token important for image classification tasks?
hard
A. It aggregates information from all patches via attention to produce a final image representation
B. It stores the positional information of patches
C. It applies convolution to patches before transformer layers
D. It normalizes the patch embeddings before feeding to the transformer

Solution

  1. Step 1: Understand class token role

    The class token is a special token that attends to all patch tokens and gathers their information.
  2. Step 2: Use in classification

    After transformer layers, the class token embedding is used as the image's summary representation for classification.
  3. Final Answer:

    It aggregates information from all patches via attention to produce a final image representation -> Option A
  4. Quick Check:

    Class token = image summary for classification
Hint: Class token collects info from patches for final decision
Common Mistakes:
  • Confusing class token with positional encoding
  • Thinking class token applies convolution
  • Assuming class token normalizes embeddings
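To make the aggregation concrete, the sketch below prepends a class token to random patch tokens and inspects the attention weights returned by `nn.MultiheadAttention` (with `need_weights=True`, the weights are averaged over heads by default). The class token's row is a probability distribution over every token in the sequence, so its output is a weighted mix of information from all patches. The sizes and the random class token are assumptions for illustration; in a trained ViT the class token is learned.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, num_patches, dim, heads = 2, 16, 64, 4

attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)

cls_token = torch.randn(1, 1, dim)  # stands in for a learned class token
patches = torch.randn(batch, num_patches, dim)
x = torch.cat([cls_token.expand(batch, -1, -1), patches], dim=1)  # (2, 17, 64)

with torch.no_grad():
    out, weights = attn(x, x, x, need_weights=True)  # weights: (batch, 17, 17)

cls_weights = weights[:, 0]  # how strongly the class token attends to each of the 17 tokens
print(cls_weights.shape)        # torch.Size([2, 17])
print(cls_weights.sum(dim=-1))  # each row sums to 1, up to float rounding
cls_summary = out[:, 0]         # (batch, dim): the per-image summary a classifier head would use
```

Stacking several such layers lets the class token refine this summary, and the sample model's `self.head` reads the class token's final embedding to produce class scores.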