Multi-head attention lets a model attend to information from several representation subspaces at the same time. This makes it better at capturing complex relationships in data such as sentences or images.
Multi-head attention in PyTorch
torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)
embed_dim: Size of each input embedding vector.
num_heads: Number of parallel attention heads. embed_dim must be divisible by num_heads, since each head works on embed_dim // num_heads dimensions.
mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=8)
mha = torch.nn.MultiheadAttention(embed_dim=128, num_heads=4, dropout=0.1, batch_first=True)
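A quick sketch of the divisibility constraint mentioned above: each head receives an equal slice of the embedding, and PyTorch rejects combinations that don't divide evenly.

```python
import torch.nn as nn

# embed_dim must be divisible by num_heads: each head attends over
# embed_dim // num_heads dimensions.
embed_dim, num_heads = 128, 4
head_dim = embed_dim // num_heads  # 32 dimensions per head
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

# An indivisible combination fails at construction time.
try:
    nn.MultiheadAttention(embed_dim=128, num_heads=5)
except Exception as e:
    print(e)
```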
The following example creates a multi-head attention layer with 4 heads and 16-dimensional embeddings, runs random data through it, and prints the shapes and small samples of the output and attention weights.
import torch
import torch.nn as nn

# Set seed for reproducibility
torch.manual_seed(0)

# Parameters
embed_dim = 16
num_heads = 4
seq_length = 5
batch_size = 2

# Create MultiheadAttention layer
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

# Random input tensors (batch, seq, embed_dim)
query = torch.rand(batch_size, seq_length, embed_dim)
key = torch.rand(batch_size, seq_length, embed_dim)
value = torch.rand(batch_size, seq_length, embed_dim)

# Forward pass
output, attn_weights = mha(query, key, value)

# Print shapes and a small part of output and attention weights
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"Output sample:\n{output[0, :2, :4]}")
print(f"Attention weights sample:\n{attn_weights[0, :2, :5]}")
Input shapes must match the expected format: (batch, sequence, embedding) when batch_first=True, or (sequence, batch, embedding) with the default batch_first=False.
Multi-head attention splits embeddings into parts, applies attention separately, then combines results.
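The split-attend-combine step can be sketched by hand with tensor operations. This is a simplified illustration, not PyTorch's internal implementation (the real layer also applies learned linear projections to the query, key, value, and output):

```python
import torch

torch.manual_seed(0)
batch, seq, embed_dim, num_heads = 2, 5, 16, 4
head_dim = embed_dim // num_heads

q = torch.rand(batch, seq, embed_dim)
k = torch.rand(batch, seq, embed_dim)
v = torch.rand(batch, seq, embed_dim)

# Split: (batch, seq, embed_dim) -> (batch, num_heads, seq, head_dim)
def split_heads(x):
    return x.view(batch, seq, num_heads, head_dim).transpose(1, 2)

qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)

# Scaled dot-product attention, computed independently for each head
scores = qh @ kh.transpose(-2, -1) / head_dim ** 0.5
weights = torch.softmax(scores, dim=-1)  # (batch, heads, seq, seq)
per_head = weights @ vh                  # (batch, heads, seq, head_dim)

# Combine: concatenate the heads back into a single embedding
out = per_head.transpose(1, 2).reshape(batch, seq, embed_dim)
print(out.shape)  # torch.Size([2, 5, 16])
```

Because each head attends over its own slice, different heads are free to learn different attention patterns.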
Attention weights show how much each position focuses on others.
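Since each row of attention weights is a probability distribution over the key positions, the weights along the last dimension sum to 1. A small self-attention check (note that by default the returned weights are averaged across heads):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
x = torch.rand(2, 5, 16)

# Self-attention: the same tensor serves as query, key, and value
_, attn_weights = mha(x, x, x)

# Each query position distributes a total weight of 1 over the keys
print(attn_weights.sum(dim=-1))  # all ones, shape (2, 5)
```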
Multi-head attention lets models focus on different parts of data at once.
It improves understanding of complex patterns in language and images.
PyTorch provides a simple MultiheadAttention layer to use this easily.