Self-attention and multi-head attention in NLP

Self-attention helps a model focus on the most relevant parts of a sentence when processing language. Multi-head attention lets the model look at the sentence from several different views at the same time.
Attention(Q, K, V) = softmax((Q * K^T) / sqrt(d_k)) * V

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O
    where head_i = Attention(Q * W_Qi, K * W_Ki, V * W_Vi)
Q, K, and V stand for the Query, Key, and Value matrices, all derived from the input; d_k is the dimensionality of the keys, used to scale the dot products before the softmax.
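As a quick check of the formula, scaled dot-product attention can be computed directly in a few lines of PyTorch (a minimal sketch; the tensor shapes and random values here are illustrative, not from the original):

```python
import torch
import torch.nn.functional as F

# Toy inputs: batch size 1, sequence length 3, key dimension d_k = 4
Q = torch.randn(1, 3, 4)
K = torch.randn(1, 3, 4)
V = torch.randn(1, 3, 4)

d_k = Q.shape[-1]

# softmax((Q K^T) / sqrt(d_k)) V
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # (1, 3, 3) similarity scores
weights = F.softmax(scores, dim=-1)             # each row sums to 1
output = weights @ V                            # (1, 3, 4) weighted values

print(output.shape)  # torch.Size([1, 3, 4])
```

Each output position is a weighted average of all value vectors, with the weights given by how well that position's query matches every key.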
Multi-head attention runs several attention calculations in parallel, then combines their results.
Q = input_embeddings
K = input_embeddings
V = input_embeddings
output = Attention(Q, K, V)
head_1 = Attention(Q * W_Q1, K * W_K1, V * W_V1)
head_2 = Attention(Q * W_Q2, K * W_K2, V * W_V2)
output = Concat(head_1, head_2) * W_O
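The two-head pseudocode above can be run directly with small random weight matrices (a sketch; the dimensions and the helper `attention` function are illustrative assumptions, not part of the original):

```python
import torch
import torch.nn.functional as F

def attention(Q, K, V):
    # softmax((Q K^T) / sqrt(d_k)) V
    d_k = Q.shape[-1]
    weights = F.softmax(Q @ K.transpose(-2, -1) / d_k ** 0.5, dim=-1)
    return weights @ V

torch.manual_seed(0)
Q = K = V = torch.randn(1, 3, 8)  # (batch, seq_len, embed_size)

# Per-head projection matrices: embed_size 8 -> head_dim 4
W_Q1, W_K1, W_V1 = (torch.randn(8, 4) for _ in range(3))
W_Q2, W_K2, W_V2 = (torch.randn(8, 4) for _ in range(3))
W_O = torch.randn(8, 8)  # output projection

head_1 = attention(Q @ W_Q1, K @ W_K1, V @ W_V1)
head_2 = attention(Q @ W_Q2, K @ W_K2, V @ W_V2)
output = torch.cat([head_1, head_2], dim=-1) @ W_O

print(output.shape)  # torch.Size([1, 3, 8])
```

Each head attends with its own learned projections, and the concatenation followed by W_O mixes the heads back into a single representation of the original embedding size.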
This code creates a simple self-attention layer with two heads. It takes a small input tensor and computes the self-attention output.
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, queries):
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum does batch matrix multiplication of queries and keys for each head
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        # Scale by sqrt(d_k) before the softmax
        energy = energy / (self.head_dim ** 0.5)
        attention = torch.softmax(energy, dim=3)

        # Weighted sum of values, then merge the heads back together
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        out = self.fc_out(out)
        return out


# Example usage
embed_size = 8
heads = 2
self_attention = SelfAttention(embed_size, heads)

# Batch size 1, sequence length 3, embedding size 8
x = torch.tensor([[[1., 0., 1., 0., 1., 0., 1., 0.],
                   [0., 1., 0., 1., 0., 1., 0., 1.],
                   [1., 1., 1., 1., 1., 1., 1., 1.]]])

output = self_attention(x, x, x)
print(output)
Self-attention helps the model understand relationships between words regardless of their position.
Multi-head attention allows the model to capture different types of relationships at once.
Embedding size must be divisible by the number of heads for splitting.
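PyTorch also ships a built-in multi-head attention layer, torch.nn.MultiheadAttention, which enforces the same divisibility constraint (a minimal sketch; the dimensions here are arbitrary illustrations):

```python
import torch
import torch.nn as nn

# embed_dim = 8 splits evenly across num_heads = 2, so head_dim = 4
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)

x = torch.randn(1, 3, 8)     # (batch, sequence, embedding)
out, weights = mha(x, x, x)  # self-attention: query = key = value

print(out.shape)      # torch.Size([1, 3, 8])
print(weights.shape)  # torch.Size([1, 3, 3]), averaged over heads
```

Constructing the layer with an embed_dim that is not divisible by num_heads raises an error, which is the same check the hand-written class above performs with its assert.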
Self-attention lets a model focus on important words in a sentence by comparing all words to each other.
Multi-head attention runs several self-attention processes in parallel to get richer understanding.
This technique is central to modern language models built on the Transformer architecture.