"""Simple Transformer encoder demo.

Builds a small Transformer encoder over integer token ids, applies it to a
sample sequence of word indices to predict the next words, and prints the
output shape together with the probability distribution for the first
position in the sequence.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SimpleTransformer(nn.Module):
    """Transformer encoder language model: token ids -> per-position vocab logits.

    Inputs and outputs use the (seq_len, batch) layout expected by
    ``nn.TransformerEncoder`` with its default ``batch_first=False``.
    """

    def __init__(self, vocab_size, embed_size, num_heads, hidden_dim, num_layers,
                 max_len=5000):
        """
        Args:
            vocab_size: number of distinct token ids.
            embed_size: embedding / model dimension (d_model).
            num_heads: attention heads; must evenly divide embed_size.
            hidden_dim: feed-forward hidden width inside each encoder layer.
            num_layers: number of stacked encoder layers.
            max_len: maximum sequence length supported by the positional table
                (was a hard-coded 5000; kept as the default for compatibility).
        """
        super().__init__()
        self.d_model = embed_size
        self.embedding = nn.Embedding(vocab_size, embed_size)

        # Sinusoidal positional encoding (Vaswani et al., 2017), precomputed
        # once and stored as a non-trainable buffer.
        pe = torch.zeros(max_len, self.d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # Slice div_term so an odd embed_size also works: the cosine columns
        # number d_model // 2, one fewer than div_term's length when d_model is
        # odd. For even d_model the slice is a no-op (original behavior).
        pe[:, 1::2] = torch.cos(position * div_term[: self.d_model // 2])
        # Shape (max_len, 1, d_model) so it broadcasts over the batch dimension.
        self.register_buffer('pos_encoder', pe.unsqueeze(1))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model, nhead=num_heads, dim_feedforward=hidden_dim
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(self.d_model, vocab_size)

    def forward(self, src, src_mask=None):
        """Return logits of shape (seq_len, batch, vocab_size).

        Args:
            src: LongTensor of token ids, shape (seq_len, batch).
            src_mask: optional (seq_len, seq_len) attention mask, e.g. a causal
                mask from ``nn.Transformer.generate_square_subsequent_mask``.
                ``None`` (the default) keeps the original unmasked behavior.
        """
        # Scale embeddings by sqrt(d_model) as in the original Transformer paper.
        embedded = self.embedding(src) * math.sqrt(self.d_model)
        embedded = embedded + self.pos_encoder[:embedded.size(0)]
        encoded = self.encoder(embedded, mask=src_mask)  # (seq_len, batch, d_model)
        return self.fc_out(encoded)  # (seq_len, batch, vocab_size)
# Demo: run a tiny model on a single 5-token sequence (batch size 1).
vocab_size, embed_size, num_heads, hidden_dim, num_layers = 10, 8, 2, 16, 1

model = SimpleTransformer(vocab_size, embed_size, num_heads, hidden_dim, num_layers)

# Token ids arranged as (seq_len=5, batch=1), the layout the model expects.
src = torch.tensor([[1], [2], [3], [4], [5]])

output = model(src)  # (5, 1, vocab_size)
probs = F.softmax(output, dim=-1)  # normalize logits over the vocabulary

# Report the result shape and the distribution predicted at the first position.
print(f"Output shape: {output.shape}")
print(f"Probabilities for first token:\n{probs[0, 0].detach().numpy()}")