Self-attention calculates how much each word in a sentence should attend to every other word. It does this by computing similarity scores between queries and keys, passing them through a softmax, and then using the resulting weights to take a weighted sum of the value vectors.
# Demonstrate PyTorch's built-in multi-head attention on random tensors.
import torch
import torch.nn as nn

batch_size = 2
seq_len = 5
embed_dim = 16
num_heads = 4

# Default nn.MultiheadAttention layout is (seq_len, batch, embed_dim).
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)

query = torch.rand(seq_len, batch_size, embed_dim)
key = torch.rand(seq_len, batch_size, embed_dim)
value = torch.rand(seq_len, batch_size, embed_dim)

# The module returns (attn_output, attn_weights); only the output is needed here.
output, _ = mha(query, key, value)
print(output.shape)
By default (batch_first=False), PyTorch's MultiheadAttention expects inputs with shape (sequence_length, batch_size, embedding_dim) and outputs the same shape; constructing it with batch_first=True switches both to (batch_size, sequence_length, embedding_dim).
Increasing the number of heads splits the embedding dimension into smaller parts per head (so the embedding dimension must be divisible by the number of heads). This allows each head to learn different aspects or subspaces of the input features.
High attention weights mean the model is focusing more on the related token when processing the current token, indicating importance or relevance.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleMultiHeadAttention(nn.Module):
    """Minimal from-scratch multi-head self-attention.

    Projects the input into queries, keys and values, splits each into
    ``num_heads`` heads of size ``embed_dim // num_heads``, applies scaled
    dot-product attention per head, then concatenates the heads and mixes
    them with a final linear projection.
    """

    def __init__(self, embed_dim, num_heads):
        """
        Args:
            embed_dim: Total embedding dimension of the input.
            num_heads: Number of attention heads; must evenly divide embed_dim.

        Raises:
            ValueError: If embed_dim is not divisible by num_heads.
        """
        super().__init__()
        if embed_dim % num_heads != 0:
            # Without this check, the view() calls in forward() would fail
            # later with a confusing shape error.
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.out_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Apply self-attention to x of shape (batch, seq_len, embed_dim).

        Returns a tensor of the same shape as x.
        """
        batch_size, seq_len, embed_dim = x.size()
        Q = self.q_linear(x)
        K = self.k_linear(x)
        V = self.v_linear(x)
        # Split the last axis into (num_heads, head_dim) ...
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)
        # ... then move heads ahead of the sequence axis:
        # (batch_size, num_heads, seq_len, head_dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        # Scaled dot-product attention; dividing by sqrt(head_dim) keeps the
        # softmax logits in a numerically stable range.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)
        # Concatenate heads back into a single embed_dim vector per token.
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        out = self.out_linear(out)
        return out


x = torch.rand(2, 5, 16)
model = SimpleMultiHeadAttention(embed_dim=16, num_heads=4)
output = model(x)
The code correctly reshapes and transposes the tensors for multi-head attention. The matrix multiplications have matching dimensions, and the final output shape, (2, 5, 16), matches the input shape exactly. Therefore, no error occurs.