import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    """Computes softmax(QK^T / temperature) V."""

    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature

    def forward(self, q, k, v, mask=None):
        # (batch, n_head, len_q, d_k) x (batch, n_head, d_k, len_k)
        # -> (batch, n_head, len_q, len_k)
        attn = torch.matmul(q, k.transpose(-2, -1)) / self.temperature
        if mask is not None:
            # Positions where mask == 0 get -inf scores, i.e. zero weight after softmax.
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        # Weighted sum of values: (batch, n_head, len_q, d_v)
        output = torch.matmul(attn, v)
        return output, attn


class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, d_model, d_k, d_v):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        # Project d_model-dim inputs into n_head separate d_k / d_v subspaces.
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
        # Scale scores by sqrt(d_k) so dot products don't saturate the softmax.
        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

    def forward(self, q, k, v, mask=None):
        batch_size, len_q, _ = q.size()
        len_k = k.size(1)
        len_v = v.size(1)
        # Split each projection into heads: (batch, n_head, seq_len, d_k or d_v)
        q = self.w_qs(q).view(batch_size, len_q, self.n_head, self.d_k).transpose(1, 2)
        k = self.w_ks(k).view(batch_size, len_k, self.n_head, self.d_k).transpose(1, 2)
        v = self.w_vs(v).view(batch_size, len_v, self.n_head, self.d_v).transpose(1, 2)
        if mask is not None:
            # Broadcast the same mask across all heads: (batch, 1, len_q, len_k)
            mask = mask.unsqueeze(1)
        output, attn = self.attention(q, k, v, mask=mask)
        # Concatenate heads back together: (batch, len_q, n_head * d_v)
        output = output.transpose(1, 2).contiguous().view(batch_size, len_q, -1)
        output = self.fc(output)
        return output, attn


# Example usage with dummy data
batch_size = 2
seq_len = 5
d_model = 16
n_head = 4
d_k = d_v = d_model // n_head
x = torch.rand(batch_size, seq_len, d_model) # input embeddings
mha = MultiHeadAttention(n_head=n_head, d_model=d_model, d_k=d_k, d_v=d_v)
output, attention_weights = mha(x, x, x)  # self-attention: q = k = v
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")