PyTorch · ~5 mins

Multi-head attention in PyTorch

Introduction

Multi-head attention lets a model attend to information from several different perspectives at the same time. Each "head" learns to focus on different relationships in the data, which makes the model better at understanding complex inputs like sentences or images.

Common use cases:

When building language translation models to understand words in context.
When creating chatbots that need to understand multiple parts of a conversation.
When analyzing images where different parts relate to each other.
When summarizing long documents by focusing on important sections.
When improving recommendation systems by looking at various user behaviors.
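Each head performs the same core operation: scaled dot-product attention. The following is a minimal sketch of a single head (illustrative only; the library's implementation also applies learned projections):

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    # Scores: how well each query position matches each key position,
    # scaled by sqrt of the embedding size to keep values stable
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    weights = F.softmax(scores, dim=-1)  # each row sums to 1
    return weights @ v, weights

q = torch.rand(2, 5, 16)  # (batch, seq, embed_dim)
k = torch.rand(2, 5, 16)
v = torch.rand(2, 5, 16)
out, w = scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 5, 16])
print(w.shape)    # torch.Size([2, 5, 5])
```

Multi-head attention runs several of these in parallel, each on a slice of the embedding.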
Syntax
PyTorch
torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)

embed_dim: Size of each input embedding vector.

num_heads: How many parallel attention heads to use. embed_dim must be divisible by num_heads.
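The divisibility requirement exists because each head works on embed_dim // num_heads dimensions. A quick illustration of what happens when it is violated:

```python
import torch.nn as nn

# Works: 64 / 8 = 8 dimensions per head
mha = nn.MultiheadAttention(embed_dim=64, num_heads=8)

# Fails: 64 is not divisible by 6, so PyTorch raises an AssertionError
try:
    nn.MultiheadAttention(embed_dim=64, num_heads=6)
except AssertionError as e:
    print(f"Error: {e}")
```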

Examples
Creates a multi-head attention layer with 64-dimensional embeddings and 8 heads.
PyTorch
mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=8)
Creates a multi-head attention layer with dropout and input/output shaped as (batch, seq, feature).
PyTorch
mha = torch.nn.MultiheadAttention(embed_dim=128, num_heads=4, dropout=0.1, batch_first=True)
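Attention can also be restricted with masks at call time. A sketch using key_padding_mask to make the layer ignore padded positions (True marks positions to skip; the shapes here are illustrative):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

x = torch.rand(2, 6, 32)  # (batch, seq, embed_dim)

# Mark the last two positions of each sequence as padding (True = ignore)
key_padding_mask = torch.zeros(2, 6, dtype=torch.bool)
key_padding_mask[:, 4:] = True

output, weights = mha(x, x, x, key_padding_mask=key_padding_mask)
print(weights[0, 0])  # weights for the padded key positions are 0
```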
Sample Model

This code creates a multi-head attention layer with 4 heads and 16-dimensional embeddings. It runs random data through it and prints the shapes and small samples of the output and attention weights.

PyTorch
import torch
import torch.nn as nn

# Set seed for reproducibility
torch.manual_seed(0)

# Parameters
embed_dim = 16
num_heads = 4
seq_length = 5
batch_size = 2

# Create MultiheadAttention layer
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

# Random input tensors (batch, seq, embed_dim)
query = torch.rand(batch_size, seq_length, embed_dim)
key = torch.rand(batch_size, seq_length, embed_dim)
value = torch.rand(batch_size, seq_length, embed_dim)

# Forward pass (attn_weights are averaged over heads by default)
output, attn_weights = mha(query, key, value)

# Print shapes and a small part of output and attention weights
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"Output sample:\n{output[0, :2, :4]}")
print(f"Attention weights sample:\n{attn_weights[0, :2, :5]}")
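By default the returned attention weights are averaged across the heads, so their shape is (batch, seq, seq). To inspect each head separately, pass average_attn_weights=False (available in recent PyTorch versions):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
x = torch.rand(2, 5, 16)

# average_attn_weights=False returns one weight matrix per head
output, per_head = mha(x, x, x, average_attn_weights=False)
print(per_head.shape)  # torch.Size([2, 4, 5, 5]) -> (batch, heads, seq, seq)
```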
Important Notes

Input shapes must match the expected format: (batch, sequence, embedding) if batch_first=True.

Multi-head attention splits embeddings into parts, applies attention separately, then combines results.

Attention weights show how much each query position focuses on each key position.
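The splitting described above can be sketched with plain tensor reshapes (a simplified illustration; the real layer also applies learned projection matrices before and after):

```python
import torch

batch, seq, embed_dim, num_heads = 2, 5, 16, 4
head_dim = embed_dim // num_heads  # 4 dimensions per head

x = torch.rand(batch, seq, embed_dim)

# Split the embedding into num_heads chunks of head_dim each
heads = x.view(batch, seq, num_heads, head_dim).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 4, 5, 4]) -> (batch, heads, seq, head_dim)

# After attention runs per head, the results are merged back together
merged = heads.transpose(1, 2).reshape(batch, seq, embed_dim)
print(merged.shape)  # torch.Size([2, 5, 16])
```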

Summary

Multi-head attention lets models focus on different parts of data at once.

It improves understanding of complex patterns in language and images.

PyTorch's MultiheadAttention layer makes this mechanism easy to use directly.