"""Minimal Vision Transformer (ViT) demo.

Splits images into patches, adds positional information, processes the
patch sequence with transformer blocks, and outputs class scores. The
script then runs the model on two random images and prints the results.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class PatchEmbedding(nn.Module):
    """Turn an image into a sequence of flattened patch embeddings.

    A single strided convolution both cuts the image into non-overlapping
    ``patch_size`` x ``patch_size`` tiles and projects each tile to
    ``embed_dim`` — the standard ViT patchify trick.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        # kernel == stride, so each output spatial position is exactly one patch
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Map (B, C, H, W) images to (B, num_patches, embed_dim) tokens."""
        patches = self.proj(x)                 # (B, embed_dim, H/p, W/p)
        tokens = patches.flatten(start_dim=2)  # (B, embed_dim, N)
        return tokens.permute(0, 2, 1)         # (B, N, embed_dim)
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Self-attention followed by a two-layer GELU MLP, each sub-layer
    wrapped in LayerNorm (applied before) and a residual connection.
    """

    def __init__(self, dim, heads, mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        # batch_first is left at its default (False): inputs are (seq, batch, dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim),
        )

    def forward(self, x):
        """Apply the attention and MLP sub-layers with residual adds."""
        normed = self.norm1(x)
        # MultiheadAttention returns (output, attn_weights); only [0] is needed
        x = x + self.attn(normed, normed, normed)[0]
        return x + self.mlp(self.norm2(x))
class VisionTransformer(nn.Module):
    """Minimal Vision Transformer (ViT) classifier.

    Pipeline: patch embedding -> prepend learnable [CLS] token -> add
    learnable position embeddings -> stack of transformer blocks ->
    LayerNorm -> linear head on the [CLS] token.

    Args:
        image_size: Input image height/width (assumed square).
        patch_size: Side of each square patch; must divide image_size.
        num_classes: Size of the output logit vector.
        dim: Embedding dimension used throughout the transformer.
        depth: Number of transformer blocks.
        heads: Attention heads per block.
        mlp_dim: Hidden width of each block's MLP.
        in_chans: Number of input image channels. New parameter; defaults
            to 3, the value previously hard-coded, so existing callers
            are unaffected.
    """

    def __init__(self, image_size=224, patch_size=16, num_classes=10,
                 dim=768, depth=6, heads=8, mlp_dim=2048, in_chans=3):
        super().__init__()
        self.patch_embed = PatchEmbedding(image_size, patch_size, in_chans, dim)
        num_patches = (image_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        # Standard ViT initialization: truncated normal (std=0.02) rather
        # than all-zeros — a zero pos_embed gives every position an
        # identical embedding at the start of training.
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.transformer_blocks = nn.ModuleList(
            TransformerBlock(dim, heads, mlp_dim) for _ in range(depth)
        )
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape (B, in_chans, image_size, image_size).

        Returns:
            Logits of shape (B, num_classes).
        """
        batch = x.shape[0]
        x = self.patch_embed(x)                          # (B, N, dim)
        cls = self.cls_token.expand(batch, -1, -1)       # (B, 1, dim)
        x = torch.cat((cls, x), dim=1) + self.pos_embed  # (B, N+1, dim)
        # nn.MultiheadAttention defaults to sequence-first layout
        x = x.transpose(0, 1)                            # (N+1, B, dim)
        for block in self.transformer_blocks:
            x = block(x)
        x = self.norm(x.transpose(0, 1))                 # (B, N+1, dim)
        return self.head(x[:, 0])                        # (B, num_classes)
# Demo: push a batch of two random 3x224x224 images through an
# untrained model and show the resulting class logits.
dummy_batch = torch.randn(2, 3, 224, 224)
vit = VisionTransformer()
logits = vit(dummy_batch)
print('Output shape:', logits.shape)
print('Output values:', logits)