Introduction
A Vision Transformer (ViT) helps computers understand images by splitting each image into small patches and learning patterns across them, much as we combine individual details to see the whole picture.
Jump into concepts and practice - no test required
class VisionTransformer(nn.Module):
    """Outline of a Vision Transformer classifier.

    This is a structural sketch only: the comments mark where each stage
    belongs, but no layers are created and ``forward`` is not implemented.

    Args:
        image_size: Accepted for signature parity with a full model; unused here.
        patch_size: Accepted for signature parity with a full model; unused here.
        num_classes: Accepted for signature parity with a full model; unused here.
        dim: Accepted for signature parity with a full model; unused here.
        depth: Accepted for signature parity with a full model; unused here.
        heads: Accepted for signature parity with a full model; unused here.
        mlp_dim: Accepted for signature parity with a full model; unused here.
    """

    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim):
        super().__init__()
        # Split image into patches
        # Add position info
        # Use transformer layers
        # Classify with a final layer

    def forward(self, x):
        """Placeholder for the forward pass.

        Raises:
            NotImplementedError: Always; this outline defines no computation.
        """
        # Process input through patches and transformer
        # Return class predictions
        raise NotImplementedError(
            "VisionTransformer outline: forward() is intentionally left unimplemented"
        )
# Instantiate the model with the tutorial's example hyperparameters.
vit = VisionTransformer(
    image_size=224,
    patch_size=16,
    num_classes=10,
    dim=512,
    depth=6,
    heads=8,
    mlp_dim=1024,
)
# NOTE(review): input_tensor is not defined in this snippet; presumably a
# (batch, 3, 224, 224) image tensor supplied by the caller -- confirm.
output = vit(input_tensor)
print(output.shape)  # closes the usage snippet that precedes this listing

import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every non-overlapping patch to an ``embed_dim``-sized vector.

    Args:
        img_size: Nominal side length of the square input images. It is only
            used to record the expected patch count in ``num_patches``;
            ``forward`` does not check the input against it.
        patch_size: Side length of each square patch.
        in_chans: Number of input channels.
        embed_dim: Size of each patch embedding.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Patches per image for the nominal size (floor division drops any
        # partial patch at the border, matching what the strided conv does).
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Map ``(B, C, H, W)`` images to ``(B, N, embed_dim)`` patch tokens."""
        x = self.proj(x)  # (B, embed_dim, H/patch, W/patch)
        x = x.flatten(2)  # (B, embed_dim, N)
        return x.transpose(1, 2)  # (B, N, embed_dim)


class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block: self-attention followed by an MLP.

    Each sub-layer is preceded by LayerNorm and wrapped in a residual
    connection. ``nn.MultiheadAttention`` is built with its default
    ``batch_first=False``, so ``forward`` expects ``(seq_len, batch, dim)``.

    Args:
        dim: Token embedding size.
        heads: Number of attention heads.
        mlp_dim: Hidden size of the feed-forward MLP.
    """

    def __init__(self, dim, heads, mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim),
        )

    def forward(self, x):
        """Apply attention and MLP sub-layers; the output has the input's shape."""
        normed = self.norm1(x)
        # The attention weights are never used, so skip computing them.
        attn_output, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x


class VisionTransformer(nn.Module):
    """Vision Transformer image classifier.

    The image is split into patches, a class token is prepended, learned
    positional embeddings are added, the sequence runs through ``depth``
    Transformer blocks, and the class token's final state is classified.

    Args:
        image_size: Side length of the square input images. Inputs must yield
            the same number of patches as this size, otherwise adding the
            positional embeddings fails.
        patch_size: Side length of each square patch.
        num_classes: Number of output classes.
        dim: Token embedding size.
        depth: Number of Transformer blocks.
        heads: Number of attention heads per block.
        mlp_dim: Hidden size of each block's MLP.
    """

    def __init__(self, image_size=224, patch_size=16, num_classes=10, dim=768,
                 depth=6, heads=8, mlp_dim=2048):
        super().__init__()
        self.patch_embed = PatchEmbedding(image_size, patch_size, 3, dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        # Small random start so positions are distinguishable from the first
        # step instead of all sharing an identical zero embedding.
        nn.init.normal_(self.pos_embed, std=0.02)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(dim, heads, mlp_dim) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Classify ``(B, 3, H, W)`` images; returns ``(B, num_classes)`` logits."""
        batch_size = x.shape[0]
        x = self.patch_embed(x)  # (B, N, dim)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # (B, 1, dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, N+1, dim)
        x = x + self.pos_embed
        x = x.transpose(0, 1)  # Transformer expects (N+1, B, dim)
        for block in self.transformer_blocks:
            x = block(x)
        x = x.transpose(0, 1)  # (B, N+1, dim)
        x = self.norm(x)
        cls_output = x[:, 0]  # (B, dim)
        return self.head(cls_output)  # (B, num_classes)


if __name__ == "__main__":
    # Create dummy data: batch of 2 images, 3 channels, 224x224
    images = torch.randn(2, 3, 224, 224)
    # Create model
    model = VisionTransformer()
    # Forward pass
    outputs = model(images)
    # Print output shape and values
    print('Output shape:', outputs.shape)
    print('Output values:', outputs)
# Quiz prompt (its opening words are missing from this excerpt; presumably it
# asks for the tensor shape): "... patch_embeddings after processing a batch of
# 8 images of size 32x32 with patch size 8 and embedding dimension 64?"
patch_size = 8
embedding_dim = 64
batch_size = 8
image_size = 32
# (32 // 8) ** 2 = 16 patches per image
num_patches = (image_size // patch_size) ** 2
# Resulting shape: (batch_size, num_patches, embedding_dim) = (8, 16, 64)
patch_embeddings = torch.randn(batch_size, num_patches, embedding_dim)
class_token = torch.randn(1, 1, 64)
patches = torch.randn(8, 16, 64)
# torch.cat needs matching sizes on every dimension except the one being
# concatenated, so the single class token is expanded (as a view, no copy)
# across the batch before it is prepended to each image's patch sequence.
batched_class_token = class_token.expand(patches.shape[0], -1, -1)  # (8, 1, 64)
input_seq = torch.cat([batched_class_token, patches], dim=1)  # (8, 17, 64)