
Multi-Head Attention in PyTorch: What It Is and How It Works

In PyTorch, multi-head attention is a mechanism that allows a model to focus on different parts of input data simultaneously by using multiple attention 'heads'. It is implemented as torch.nn.MultiheadAttention, which helps models like transformers understand complex relationships in sequences.
⚙️ How It Works

Imagine you are reading a book and want to understand the story better by looking at different parts of the text at the same time. Multi-head attention works similarly by letting the model look at multiple parts of the input data in parallel. Each 'head' focuses on a different aspect or relationship in the data.

In PyTorch, multi-head attention splits the embedding dimension into several smaller subspaces, one per head (so embed_dim must be divisible by num_heads), applies attention separately in each subspace, and then concatenates and projects the results. This helps the model capture more detailed and varied relationships than a single attention mechanism.
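Conceptually, the head split can be sketched by hand. The snippet below is illustrative only: it omits the learned projection matrices that nn.MultiheadAttention applies to the query, key, and value, and uses the same tensor for all three.

```python
import torch

embed_dim, num_heads = 8, 2
head_dim = embed_dim // num_heads  # each head works in a smaller subspace

x = torch.rand(1, 4, embed_dim)  # (batch, seq_len, embed_dim)

# Split: (batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim)
heads = x.view(1, 4, num_heads, head_dim).transpose(1, 2)

# Scaled dot-product attention, computed independently within each head
scores = heads @ heads.transpose(-2, -1) / head_dim ** 0.5
weights = torch.softmax(scores, dim=-1)   # (batch, num_heads, seq_len, seq_len)
attended = weights @ heads                # (batch, num_heads, seq_len, head_dim)

# Concatenate the heads back into the full embedding dimension
combined = attended.transpose(1, 2).reshape(1, 4, embed_dim)
print(combined.shape)  # torch.Size([1, 4, 8])
```

Because each head attends in its own subspace, the weights tensor carries one 4×4 attention map per head; nn.MultiheadAttention does the same internally, plus learned input and output projections.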

💻 Example

This example shows how to use PyTorch's MultiheadAttention module to apply multi-head attention on a simple input tensor.

python
import torch
import torch.nn as nn

# Parameters
embed_dim = 8  # Embedding size
num_heads = 2  # Number of attention heads
seq_length = 4  # Sequence length
batch_size = 1  # Batch size

# Create the MultiheadAttention module
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

# Random input tensors: query, key, value (batch_size, seq_length, embed_dim)
query = torch.rand(batch_size, seq_length, embed_dim)
key = torch.rand(batch_size, seq_length, embed_dim)
value = torch.rand(batch_size, seq_length, embed_dim)

# Apply multi-head attention
# (attn_weights are averaged over the heads by default, hence shape (batch, seq, seq))
output, attn_weights = mha(query, key, value)

print("Output shape:", output.shape)
print("Attention weights shape:", attn_weights.shape)
print("Output tensor:", output)
print("Attention weights tensor:", attn_weights)
Output
Output shape: torch.Size([1, 4, 8])
Attention weights shape: torch.Size([1, 4, 4])
Output tensor: tensor([[[-0.0107, 0.0703, 0.0203, 0.0112, 0.0427, 0.0343, 0.0345, 0.0233],
         [-0.0103, 0.0713, 0.0205, 0.0117, 0.0431, 0.0347, 0.0348, 0.0237],
         [-0.0100, 0.0717, 0.0207, 0.0119, 0.0433, 0.0349, 0.0350, 0.0239],
         [-0.0097, 0.0720, 0.0208, 0.0121, 0.0435, 0.0351, 0.0352, 0.0241]]],
       grad_fn=<TransposeBackward0>)
Attention weights tensor: tensor([[[0.2537, 0.2463, 0.2528, 0.2472],
         [0.2537, 0.2463, 0.2528, 0.2472],
         [0.2537, 0.2463, 0.2528, 0.2472],
         [0.2537, 0.2463, 0.2528, 0.2472]]], grad_fn=<SoftmaxBackward0>)
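Note that the attention weights come back with shape (batch, seq, seq) rather than one map per head: by default the module averages the per-head weights. In recent PyTorch versions (1.11+) you can pass average_attn_weights=False to inspect each head separately; a quick sketch:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
x = torch.rand(1, 4, 8)

# average_attn_weights=False returns one attention map per head:
# shape (batch, num_heads, seq_len, seq_len) instead of (batch, seq_len, seq_len)
output, per_head = mha(x, x, x, average_attn_weights=False)
print(per_head.shape)  # torch.Size([1, 2, 4, 4])
```

This is handy when you want to visualize what each head is attending to.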
🎯 When to Use

Multi-head attention is especially useful in models that process sequences, like language or time series data. It helps the model understand complex relationships by looking at different parts of the input at once.

Common real-world uses include machine translation, text summarization, and speech recognition. It is a core part of transformer models that power many modern AI applications.
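In transformers, the module is typically used for self-attention, where the sequence attends to itself (query = key = value), wrapped in a residual connection and layer normalization. The SimpleBlock name and sizes below are made up for illustration; real transformer blocks also include a feed-forward sublayer.

```python
import torch
import torch.nn as nn

class SimpleBlock(nn.Module):
    """A simplified transformer-style sublayer: self-attention + residual + norm."""

    def __init__(self, embed_dim=8, num_heads=2):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Self-attention: the same tensor serves as query, key, and value
        attn_out, _ = self.attn(x, x, x)
        # Residual connection followed by layer normalization
        return self.norm(x + attn_out)

block = SimpleBlock()
x = torch.rand(1, 4, 8)
print(block(x).shape)  # torch.Size([1, 4, 8])
```

The output keeps the input's shape, so blocks like this can be stacked to build deeper models.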

Key Points

  • Multi-head attention splits the embedding dimension across multiple heads so the model can focus on different information simultaneously.
  • PyTorch provides torch.nn.MultiheadAttention for easy implementation.
  • It improves model understanding of sequences by capturing varied relationships.
  • Widely used in transformer models for natural language processing and other sequence tasks.

Key Takeaways

  • Multi-head attention lets models focus on different parts of input data at the same time.
  • PyTorch's MultiheadAttention module makes it simple to add this mechanism to your models.
  • It is essential for transformer models that handle language and sequence data.
  • Using multiple heads helps capture richer and more diverse information.
  • Ideal for tasks like translation, summarization, and speech recognition.