How to Use nn.MultiheadAttention in PyTorch: Syntax and Example
Use nn.MultiheadAttention in PyTorch by creating an instance with the embedding dimension and number of heads, then calling it with query, key, and value tensors. It returns the attention output and the attention weights, which are useful for capturing relationships in sequences.

Syntax
The nn.MultiheadAttention module requires the embedding dimension and the number of attention heads. You call it with query, key, and value tensors, each shaped (sequence_length, batch_size, embed_dim) by default. It returns the attended output and the attention weights.
- embed_dim: Size of each input embedding vector.
- num_heads: Number of parallel attention heads.
- query, key, value: Input tensors for attention.
- attn_output: Output tensor after attention.
- attn_output_weights: Attention weights showing focus.
```python
import torch

mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2)
query = torch.rand(5, 3, 8)  # (seq_len=5, batch=3, embed_dim=8)
key = torch.rand(5, 3, 8)
value = torch.rand(5, 3, 8)
attn_output, attn_output_weights = mha(query, key, value)
```
Example
This example creates a MultiheadAttention layer, prepares random input tensors, and retrieves the attention output and weights, demonstrating the expected input and output shapes.
```python
import torch
import torch.nn as nn

# Create MultiheadAttention with embedding size 8 and 2 heads
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)

# Random input: sequence length 5, batch size 3, embedding 8
query = torch.rand(5, 3, 8)
key = torch.rand(5, 3, 8)
value = torch.rand(5, 3, 8)

# Forward pass
attn_output, attn_weights = mha(query, key, value)
print('Attention output shape:', attn_output.shape)
print('Attention weights shape:', attn_weights.shape)
```
Output
Attention output shape: torch.Size([5, 3, 8])
Attention weights shape: torch.Size([3, 5, 5])
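By default the returned weights are averaged across the heads, which is why the weight tensor above has no head dimension. If you need each head's weights separately, recent PyTorch versions accept an average_attn_weights flag in the forward call; a minimal sketch:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
query = torch.rand(5, 3, 8)  # (seq_len=5, batch=3, embed_dim=8)

# average_attn_weights=False keeps one weight matrix per head
attn_output, attn_weights = mha(query, query, query, average_attn_weights=False)
print('Per-head weights shape:', attn_weights.shape)  # (batch, num_heads, seq, seq)
```

Here self-attention is used (query, key, and value are the same tensor), so each weight matrix is square.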
Common Pitfalls
- Shape mismatch: Inputs must be (sequence_length, batch_size, embed_dim), not batch first.
- Embedding dimension: Must be divisible by number of heads.
- Using batch_first: By default, MultiheadAttention expects sequence-first input; set batch_first=True if your data is batch first.
- Ignoring attention weights: They help you understand what the model focuses on.
```python
import torch
import torch.nn as nn

# Pitfall: passing batch-first tensors to a default (sequence-first) module
# raises no error here -- the dimensions are silently misinterpreted
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
query = torch.rand(3, 5, 8)  # intended as (batch=3, seq=5, embed=8)
key = torch.rand(3, 5, 8)
value = torch.rand(3, 5, 8)
attn_output, attn_weights = mha(query, key, value)
print('Weights shape (dims misread as seq=3, batch=5):', attn_weights.shape)

# Right: declare the batch-first layout explicitly
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
attn_output, attn_weights = mha(query, key, value)
print('Weights shape with batch_first=True:', attn_weights.shape)
```
Output
Weights shape (dims misread as seq=3, batch=5): torch.Size([5, 3, 3])
Weights shape with batch_first=True: torch.Size([3, 5, 5])
Quick Reference
| Parameter | Description |
|---|---|
| embed_dim | Embedding size of input vectors |
| num_heads | Number of attention heads (embed_dim must be divisible by this) |
| batch_first | If True, input shape is (batch, seq, embed) instead of (seq, batch, embed) |
| dropout | Dropout probability on attention weights (optional) |
| attn_mask | Optional mask to prevent attention to certain positions |
| key_padding_mask | Optional mask to ignore padding tokens in attention |
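As a sketch of the two mask parameters from the table, assuming boolean masks (where True marks positions to block), a causal attn_mask and a per-batch key_padding_mask can be combined like this:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
query = torch.rand(2, 4, 8)  # (batch=2, seq=4, embed=8)

# Causal attn_mask: True above the diagonal blocks attention to future positions
causal_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)

# key_padding_mask: True marks padding tokens to ignore, per batch element
key_padding_mask = torch.tensor([[False, False, False, True],
                                 [False, False, True, True]])

out, weights = mha(query, query, query,
                   attn_mask=causal_mask,
                   key_padding_mask=key_padding_mask)
print(weights[0])  # masked positions receive exactly zero weight
```

Blocked positions are set to negative infinity before the softmax, so their weights come out as exact zeros.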
Key Takeaways
- Create nn.MultiheadAttention with the embedding size and number of heads before use.
- Input tensors must have shape (sequence_length, batch_size, embed_dim) unless batch_first=True.
- The embedding dimension must be divisible by the number of heads.
- The module returns both the attention output and the attention weights.
- Use the attention weights to understand which parts of the input the model focuses on.
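To illustrate the last takeaway: each row of the weight matrix is a softmax distribution over the key positions, so its entries sum to 1. A quick sanity check:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
x = torch.rand(5, 3, 8)
_, attn_weights = mha(x, x, x)  # averaged weights: (batch, tgt_seq, src_seq)

# Each query position's weights form a probability distribution over keys
row_sums = attn_weights.sum(dim=-1)
print(torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5))
```

Large weights in a row indicate the key positions that query position attends to most strongly.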