import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
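    """Multi-head self-attention over (batch, seq_len, embed_dim) inputs (no masking)."""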
def __init__(self, embed_dim, num_heads, dropout=0.1):
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.q_linear = nn.Linear(embed_dim, embed_dim)
self.k_linear = nn.Linear(embed_dim, embed_dim)
self.v_linear = nn.Linear(embed_dim, embed_dim)
self.out_linear = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
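        # Scale dot-product scores by sqrt(head_dim) to keep their magnitude stable.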
self.scale = self.head_dim ** 0.5
def forward(self, x):
batch_size, seq_len, embed_dim = x.size()
# Linear projections
Q = self.q_linear(x) # (batch_size, seq_len, embed_dim)
K = self.k_linear(x)
V = self.v_linear(x)
# Split into heads
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# Scaled dot-product attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale # (batch_size, num_heads, seq_len, seq_len)
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn) # Apply dropout to attention weights
out = torch.matmul(attn, V) # (batch_size, num_heads, seq_len, head_dim)
# Concatenate heads
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
# Final linear layer
out = self.out_linear(out)
out = self.dropout(out) # Apply dropout after output projection
return out
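
# Quick shape sanity check (illustrative): self-attention preserves the input shape,
# so a (batch, seq_len, embed_dim) tensor should come back with the same dimensions.
_mha = MultiHeadAttention(embed_dim=64, num_heads=4)
assert _mha(torch.randn(2, 5, 64)).shape == (2, 5, 64)
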
# Simple transformer block built on the MultiHeadAttention module above
class TransformerBlock(nn.Module):
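    """Post-LayerNorm encoder block: self-attention and a feed-forward network,
    each wrapped in a residual connection followed by LayerNorm."""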
def __init__(self, embed_dim, num_heads, dropout=0.1):
super().__init__()
self.mha = MultiHeadAttention(embed_dim, num_heads, dropout)
self.norm1 = nn.LayerNorm(embed_dim)
self.ff = nn.Sequential(
nn.Linear(embed_dim, embed_dim * 4),
nn.ReLU(),
nn.Linear(embed_dim * 4, embed_dim),
nn.Dropout(dropout)
)
self.norm2 = nn.LayerNorm(embed_dim)
def forward(self, x):
attn_out = self.mha(x)
x = self.norm1(x + attn_out)
ff_out = self.ff(x)
x = self.norm2(x + ff_out)
return x
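
# Illustrative usage sketch: blocks preserve the (batch, seq_len, embed_dim) shape,
# so several of them can be stacked directly, e.g.:
#   encoder = nn.Sequential(*[TransformerBlock(64, 4) for _ in range(2)])
#   encoder(torch.randn(2, 10, 64)).shape  # torch.Size([2, 10, 64])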
# Dummy dataset and training loop for demonstration
import torch.optim as optim
def train_model():
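    """Train a single TransformerBlock (plus a linear head) on random dummy data as a demo."""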
torch.manual_seed(42)
embed_dim = 64
num_heads = 4
dropout = 0.2
    model = TransformerBlock(embed_dim, num_heads, dropout)
    # The block outputs (batch, seq_len, embed_dim); a small linear head maps the
    # pooled representation to the two class logits expected by CrossEntropyLoss.
    classifier = nn.Linear(embed_dim, 2)
    optimizer = optim.Adam(list(model.parameters()) + list(classifier.parameters()), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    # Dummy data: 100 train / 30 val samples, seq_len=10, embed_dim=64, 2 classes
X_train = torch.randn(100, 10, embed_dim)
y_train = torch.randint(0, 2, (100,))
X_val = torch.randn(30, 10, embed_dim)
y_val = torch.randint(0, 2, (30,))
for epoch in range(30):
model.train()
optimizer.zero_grad()
        out = model(X_train)      # (100, 10, 64)
        out = out.mean(dim=1)     # mean-pool over the sequence -> (100, 64)
        out = classifier(out)     # class logits -> (100, 2)
        loss = criterion(out, y_train)
loss.backward()
optimizer.step()
model.eval()
with torch.no_grad():
            val_out = classifier(model(X_val).mean(dim=1))  # (30, 2) logits
            val_loss = criterion(val_out, y_val)
if epoch % 5 == 0 or epoch == 29:
train_acc = (out.argmax(dim=1) == y_train).float().mean().item() * 100
val_acc = (val_out.argmax(dim=1) == y_val).float().mean().item() * 100
print(f"Epoch {epoch}: Train loss {loss.item():.3f}, Val loss {val_loss.item():.3f}, Train acc {train_acc:.1f}%, Val acc {val_acc:.1f}%")
if __name__ == "__main__":
    train_model()