
Self-attention and multi-head attention in NLP

Introduction

Self-attention lets a model weigh every word in a sentence against every other word, so it can focus on the parts that matter most for meaning. Multi-head attention runs this comparison from several learned perspectives at the same time. It shows up in tasks such as:

When translating a sentence from one language to another.
When summarizing a long article into a short paragraph.
When answering questions based on a paragraph of text.
When recognizing the meaning of words depending on context.
When building chatbots that understand user messages.
Syntax
Attention(Q, K, V) = softmax((Q * K^T) / sqrt(d_k)) * V

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O
where head_i = Attention(Q * W_Qi, K * W_Ki, V * W_Vi)

Q, K, V stand for the Query, Key, and Value matrices derived from the input; d_k is the key dimension, and dividing by sqrt(d_k) keeps the dot products at a stable scale before the softmax.

Multi-head attention runs several attention calculations in parallel, then combines their results.
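As a sketch, the scaled dot-product formula above can be written directly in PyTorch (the `attention` helper name here is just for illustration):

```python
import torch

def attention(Q, K, V):
    # softmax((Q @ K^T) / sqrt(d_k)) @ V
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
    weights = torch.softmax(scores, dim=-1)  # one attention distribution per query
    return weights @ V

# Toy input: 3 tokens with 3-dimensional embeddings
Q = K = V = torch.eye(3)
out = attention(Q, K, V)
print(out.shape)  # torch.Size([3, 3])
```

Each row of `weights` sums to 1, so every output vector is a weighted average of the value vectors.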

Examples
This is self-attention, where the input attends to itself.
Q = input_embeddings
K = input_embeddings
V = input_embeddings
output = Attention(Q, K, V)
This shows multi-head attention with two heads, each viewing the input through its own learned projections.
head_1 = Attention(Q * W_Q1, K * W_K1, V * W_V1)
head_2 = Attention(Q * W_Q2, K * W_K2, V * W_V2)
output = Concat(head_1, head_2) * W_O
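PyTorch also ships a built-in layer that bundles the per-head projections and the final W_O concat-projection shown above; a minimal self-attention call might look like this (shapes assume `batch_first=True`, available in recent PyTorch versions):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)

x = torch.rand(1, 3, 8)       # (batch, seq_len, embed_dim)
out, weights = mha(x, x, x)   # self-attention: Q = K = V = x
print(out.shape)              # torch.Size([1, 3, 8])
print(weights.shape)          # torch.Size([1, 3, 3]), averaged over the heads
```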
Sample Model

This code creates a simple self-attention layer with two heads. It takes a small input tensor and computes the self-attention output.

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, queries):
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum does batch matrix multiplication for query*keys for each training example
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        # Scale energy
        energy = energy / (self.head_dim ** 0.5)

        # Normalize scores over the key dimension
        attention = torch.softmax(energy, dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

# Example usage
embed_size = 8
heads = 2
self_attention = SelfAttention(embed_size, heads)

# Batch size 1, sequence length 3, embedding size 8
x = torch.tensor([[[1., 0., 1., 0., 1., 0., 1., 0.],
                   [0., 1., 0., 1., 0., 1., 0., 1.],
                   [1., 1., 1., 1., 1., 1., 1., 1.]]])

output = self_attention(x, x, x)
print(output)
Important Notes

Self-attention helps the model understand relationships between words regardless of their position.

Multi-head attention allows the model to capture different types of relationships at once.

The embedding size must be divisible by the number of heads so the embedding can be split evenly across them.
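The split is just a reshape, which is why the divisibility requirement matters; a quick sketch:

```python
import torch

x = torch.rand(1, 3, 8)                   # (batch, seq_len, embed_size)
heads = 2
head_dim = 8 // heads                     # 4; this only works because 8 % 2 == 0

split = x.reshape(1, 3, heads, head_dim)  # one slice of the embedding per head
merged = split.reshape(1, 3, heads * head_dim)
print(torch.equal(x, merged))             # True: splitting and merging is lossless
```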

Summary

Self-attention lets a model focus on important words in a sentence by comparing all words to each other.

Multi-head attention runs several self-attention processes in parallel to get richer understanding.

This technique is at the core of modern language models built on the Transformer architecture.