import torch
import torch.nn as nn
import torch.optim as optim
import math

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding, added to the token embeddings."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Frequencies follow the 1/10000^(2i/d_model) schedule from "Attention Is All You Need".
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # shape (1, max_len, d_model)
        self.register_buffer('pe', pe)  # saved with the model but not a learnable parameter

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the encoding for the first seq_len positions
        x = x + self.pe[:, :x.size(1)]
        return x
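
# Quick shape check (an illustrative addition, not part of the original script): the
# encoding should leave the input's shape untouched; the sizes below are arbitrary.
_pe_demo = PositionalEncoding(d_model=16)
assert _pe_demo(torch.zeros(2, 7, 16)).shape == (2, 7, 16)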

class TransformerModel(nn.Module):
    """Transformer encoder that predicts an output token for every input position."""

    def __init__(self, vocab_size, d_model=64, nhead=4, num_layers=2, dim_feedforward=128, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.decoder = nn.Linear(d_model, vocab_size)  # per-position projection back to vocabulary logits
        self.d_model = d_model

    def forward(self, src):
        # src: (batch, seq_len) of token ids
        src = self.embedding(src) * math.sqrt(self.d_model)  # scale embeddings as in the original paper
        src = self.pos_encoder(src)
        src = src.transpose(0, 1)  # the encoder expects (seq_len, batch, feature) by default
        output = self.transformer_encoder(src)
        output = output.transpose(0, 1)  # back to (batch, seq_len, feature)
        output = self.decoder(output)  # (batch, seq_len, vocab_size)
        return output
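
# Note (assumes PyTorch >= 1.9, which added batch_first to the Transformer layers):
# passing batch_first=True when building the encoder layer would keep tensors in
# (batch, seq_len, feature) order end to end and make both transposes in forward()
# unnecessary, e.g.
#     nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=128,
#                                dropout=0.1, batch_first=True)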
# Synthetic dataset: simple sequence copying task
vocab_size = 20
seq_len = 10
batch_size = 32
num_batches = 100

def generate_batch():
    # Each target is an exact copy of the input, so the model only has to learn
    # the identity mapping at every position.
    data = torch.randint(1, vocab_size, (batch_size, seq_len))
    target = data.clone()
    return data, target
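
# Illustrative check of the batch layout (an addition to the original script): both tensors
# are (batch_size, seq_len) LongTensors and identical for this copy task; token id 0 is
# never generated, so it would be free to serve as a padding id if the task were extended.
_data_demo, _target_demo = generate_batch()
assert _data_demo.shape == (batch_size, seq_len) and torch.equal(_data_demo, _target_demo)
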
model = TransformerModel(vocab_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    total_loss = 0
    correct = 0
    total = 0
    model.train()
    for _ in range(num_batches):
        data, target = generate_batch()
        optimizer.zero_grad()
        output = model(data)  # (batch, seq_len, vocab_size)
        # Flatten batch and sequence dimensions so every position is scored independently.
        loss = criterion(output.view(-1, vocab_size), target.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Token-level accuracy, measured on the training batches themselves (dropout still active).
        pred = output.argmax(dim=2)
        correct += (pred == target).sum().item()
        total += target.numel()
    train_loss = total_loss / num_batches
    train_acc = correct / total * 100
    print(f"Epoch {epoch+1}: Loss={train_loss:.4f}, Accuracy={train_acc:.2f}%")