PyTorch · How-To · Beginner · 4 min read

How to Use nn.MultiheadAttention in PyTorch: Syntax and Example

Use nn.MultiheadAttention in PyTorch by creating an instance with the embedding dimension and number of heads, then call it with query, key, and value tensors. It returns the attention output and weights, useful for capturing relationships in sequences.
📐

Syntax

The nn.MultiheadAttention module requires specifying the embedding dimension and number of attention heads. You call it with query, key, and value tensors, by default shaped as (sequence_length, batch_size, embed_dim). It returns the attended output and the attention weights.

  • embed_dim: Size of each input embedding vector.
  • num_heads: Number of parallel attention heads.
  • query, key, value: Input tensors for attention.
  • attn_output: Output tensor after attention.
  • attn_output_weights: Attention weights showing focus.
```python
import torch

mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2)
query = torch.rand(5, 3, 8)  # (seq_len=5, batch=3, embed_dim=8)
key = torch.rand(5, 3, 8)
value = torch.rand(5, 3, 8)
attn_output, attn_output_weights = mha(query, key, value)
```
💻

Example

This example shows how to create a MultiheadAttention layer, prepare random input tensors, and get the attention output and weights. It demonstrates the shape and usage clearly.

```python
import torch
import torch.nn as nn

# Create MultiheadAttention with embedding size 8 and 2 heads
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)

# Random input: sequence length 5, batch size 3, embedding dim 8
query = torch.rand(5, 3, 8)
key = torch.rand(5, 3, 8)
value = torch.rand(5, 3, 8)

# Forward pass
attn_output, attn_weights = mha(query, key, value)

print('Attention output shape:', attn_output.shape)
print('Attention weights shape:', attn_weights.shape)
```
Output
Attention output shape: torch.Size([5, 3, 8])
Attention weights shape: torch.Size([3, 5, 5])
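The weights printed above are averaged across the two heads, which is why their shape is (batch, target_len, source_len). In recent PyTorch versions (1.11+), passing average_attn_weights=False to the forward call returns one weight matrix per head instead — a minimal sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
query = key = value = torch.rand(5, 3, 8)  # self-attention: same tensor for q, k, v

# average_attn_weights=False keeps a separate weight matrix per head
attn_output, attn_weights = mha(query, key, value, average_attn_weights=False)
print(attn_weights.shape)  # torch.Size([3, 2, 5, 5]) -> (batch, heads, tgt_len, src_len)

# Each row is a softmax over source positions, so it sums to 1
print(torch.allclose(attn_weights.sum(dim=-1), torch.ones(3, 2, 5)))  # True
```

Per-head weights are useful when you want to visualize how the individual heads specialize.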
⚠️

Common Pitfalls

  • Shape conventions: By default inputs are (sequence_length, batch_size, embed_dim); batch-first tensors with a matching embed_dim are silently misinterpreted rather than rejected.
  • Embedding dimension: Must be divisible by number of heads.
  • Using batch_first: By default, MultiheadAttention expects sequence first; set batch_first=True if your data is batch first.
  • Ignoring attention weights: They help understand what the model focuses on.
```python
import torch
import torch.nn as nn

# Risky: batch-first input without batch_first=True
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
query = torch.rand(3, 5, 8)  # intended as (batch=3, seq=5, embed=8)
key = torch.rand(3, 5, 8)
value = torch.rand(3, 5, 8)

# No exception is raised: because the embedding dimension still matches,
# PyTorch treats dim 0 as the sequence and dim 1 as the batch, silently
# attending over the wrong axis.
attn_output, attn_weights = mha(query, key, value)
print('Silently misinterpreted output shape:', attn_output.shape)

# Right: specify batch_first=True so the axes are read as intended
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
attn_output, attn_weights = mha(query, key, value)
print('Output shape with batch_first=True:', attn_output.shape)
```
Output
Silently misinterpreted output shape: torch.Size([3, 5, 8])
Output shape with batch_first=True: torch.Size([3, 5, 8])

Both calls return the same shape here, which is exactly why this bug is easy to miss: only the meaning of the axes differs, so nothing in the output signals that attention ran over the wrong dimension.
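The divisibility requirement is, by contrast, checked eagerly: constructing the module with an embed_dim that is not a multiple of num_heads fails before any forward pass. A minimal sketch (PyTorch raises an AssertionError here, though the exact exception type could vary across versions):

```python
import torch.nn as nn

# embed_dim=8 cannot be split evenly across 3 heads (8 % 3 != 0),
# so construction itself fails.
try:
    nn.MultiheadAttention(embed_dim=8, num_heads=3)
    divisible = True
except Exception as e:
    divisible = False
    print('Construction failed:', e)
```

Each head works on a slice of size embed_dim // num_heads, which is why the division must be exact.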
📊

Quick Reference

| Parameter | Description |
| --- | --- |
| embed_dim | Embedding size of input vectors |
| num_heads | Number of attention heads (embed_dim must be divisible by this) |
| batch_first | If True, input shape is (batch, seq, embed) instead of (seq, batch, embed) |
| dropout | Dropout probability applied to attention weights (optional) |
| attn_mask | Optional mask to prevent attention to certain positions |
| key_padding_mask | Optional mask to ignore padding tokens in attention |
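The two masks in the table serve different purposes and can be combined. A minimal sketch, assuming boolean masks where True means "do not attend": key_padding_mask hides padding tokens per batch element, while attn_mask here blocks future positions (a causal mask):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
query = key = value = torch.rand(3, 5, 8)  # (batch=3, seq=5, embed=8)

# key_padding_mask: True marks positions to ignore (e.g. padding tokens);
# here the last two tokens of every sequence are treated as padding
key_padding_mask = torch.zeros(3, 5, dtype=torch.bool)
key_padding_mask[:, -2:] = True

# attn_mask: True above the diagonal blocks attention to future positions
attn_mask = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)

attn_output, attn_weights = mha(query, key, value,
                                key_padding_mask=key_padding_mask,
                                attn_mask=attn_mask)
print(attn_output.shape)  # torch.Size([3, 5, 8])

# Masked key positions receive exactly zero attention weight
print((attn_weights[..., -2:] == 0).all())
```

Masked positions get -inf before the softmax, so their weights are exactly zero in the returned attn_weights.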

Key Takeaways

  • Create nn.MultiheadAttention with the embedding size and number of heads before use.
  • Input tensors must have shape (sequence_length, batch_size, embed_dim) unless batch_first=True.
  • The embedding dimension must be divisible by the number of heads.
  • The module returns both the attention output and the attention weights.
  • Use the attention weights to understand which parts of the input the model focuses on.