
How to Implement Self Attention in PyTorch: Simple Guide

To implement self-attention in PyTorch, project the input into query, key, and value tensors, compute attention scores as the dot product of queries and keys scaled by the square root of the key dimension, apply softmax to turn the scores into weights, and multiply the weights by the values to get the output. This lets every position in the sequence gather information from every other position.
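Written as a formula, this is the standard scaled dot-product attention. Here X is the input sequence, W_Q, W_K, and W_V are learned projection matrices (symbol names chosen for illustration), and d_k is the dimension of each key vector, which corresponds to `head_dim` in the code below:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V,
\qquad Q = X W_Q,\quad K = X W_K,\quad V = X W_V
$$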

Syntax

Self-attention involves these main steps:

  • Query, Key, Value: Linear layers transform input into these three tensors.
  • Attention Scores: Multiply Query by the transpose of Key, then divide by the square root of the per-head dimension.
  • Softmax: Normalize scores to weights.
  • Output: Multiply weights by Value tensor.
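The four steps above can be sketched for a single head with plain tensor operations. This is a minimal illustration rather than the layer used below: the random matrices `w_q`, `w_k`, and `w_v` are hypothetical stand-ins for trained projection weights, and the sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model = 4, 8  # illustrative sizes

# A single sequence of 4 token embeddings (no batch dimension, one head)
x = torch.randn(seq_len, d_model)

# Step 1: project the input into queries, keys, and values.
# Random matrices are hypothetical stand-ins for learned weights.
w_q = torch.randn(d_model, d_model) / d_model ** 0.5
w_k = torch.randn(d_model, d_model) / d_model ** 0.5
w_v = torch.randn(d_model, d_model) / d_model ** 0.5
q, k, v = x @ w_q, x @ w_k, x @ w_v        # each (seq_len, d_model)

# Step 2: attention scores = Q K^T, scaled by sqrt(d_k)
scores = q @ k.T / d_model ** 0.5          # (seq_len, seq_len)

# Step 3: softmax over the key dimension, so each row sums to 1
weights = F.softmax(scores, dim=-1)

# Step 4: weighted sum of the value vectors
out = weights @ v                          # (seq_len, d_model)

print(weights.sum(dim=-1))  # each entry is 1 (up to floating-point rounding)
print(out.shape)            # torch.Size([4, 8])
```

The multi-head class below performs the same four steps, but splits the embedding into several heads and runs them in parallel.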
python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        assert embed_size % heads == 0, "embed_size must be divisible by heads"
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # Project the full embedding, then split the result into heads,
        # so each head has its own learned projection of the whole input
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, queries, mask=None):
        # Inputs: (N, seq_len, embed_size). mask must broadcast to
        # (N, heads, query_len, key_len), with 0 marking positions to ignore
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Linear projections on the full embedding
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Split embedding into self.heads pieces: (N, len, heads, head_dim)
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        # Einsum does a batched Q @ K^T per head: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(head_dim) and normalize over the key dimension
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # Weighted sum of values, then merge heads: (N, query_len, embed_size)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out
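The two `einsum` calls are compact but easy to misread. As an illustrative check (the tensors and sizes here are made up), the same scores and outputs can be computed with explicit `permute` calls and batched matrix multiplication on tensors laid out as `(N, len, heads, head_dim)`:

```python
import torch

torch.manual_seed(0)
# Illustrative shapes: batch N, sequence length L, heads H, head_dim D
N, L, H, D = 2, 5, 2, 4
queries = torch.randn(N, L, H, D)
keys = torch.randn(N, L, H, D)
values = torch.randn(N, L, H, D)

# Scores via einsum, as in the class: (N, H, query_len, key_len)
energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

# Same scores with permutes and a batched matmul
q = queries.permute(0, 2, 1, 3)   # (N, H, L, D)
k = keys.permute(0, 2, 3, 1)      # (N, H, D, L), i.e. transposed keys
energy_mm = q @ k                 # (N, H, L, L)
print(torch.allclose(energy, energy_mm, atol=1e-5))  # True

# The output einsum is attention @ V per head, moved back to (N, L, H, D)
attention = torch.softmax(energy / D ** 0.5, dim=3)
out = torch.einsum("nhql,nlhd->nqhd", attention, values)
out_mm = (attention @ values.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
print(torch.allclose(out, out_mm, atol=1e-5))        # True
```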

Example

This example creates a self-attention layer and runs a dummy input through it. Passing the same tensor as queries, keys, and values makes this self-attention; the layer splits the embedding into heads, computes attention, and returns an output with the same shape as the input.

python
import torch

embed_size = 8
heads = 2
batch_size = 1
seq_length = 4

self_attention = SelfAttention(embed_size, heads)

# Dummy input: batch size 1, sequence length 4, embedding size 8
x = torch.rand(batch_size, seq_length, embed_size)

# Self-attention uses same tensor for queries, keys, values
output = self_attention(x, x, x)

print("Input shape:", x.shape)
print("Output shape:", output.shape)
print("Output tensor:", output)
Output (the shapes are fixed; the tensor values differ on every run because the weights and input are random)
Input shape: torch.Size([1, 4, 8])
Output shape: torch.Size([1, 4, 8])
Output tensor: tensor([[[-0.0232,  0.0317,  0.0423,  0.0447, -0.0277,  0.0103,  0.0046,  0.0097],
         [-0.0037,  0.0229,  0.0347,  0.0387, -0.0228,  0.0123,  0.0063,  0.0113],
         [-0.0106,  0.0275,  0.0384,  0.0413, -0.0250,  0.0113,  0.0053,  0.0103],
         [-0.0073,  0.0253,  0.0363,  0.0393, -0.0238,  0.0117,  0.0057,  0.0107]]],
       grad_fn=<AddBackward0>)

Common Pitfalls

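  • Dimension mismatch: embed_size must be divisible by heads, and queries, keys, and values must be reshaped to (batch, seq_len, heads, head_dim) before the per-head einsum; passing a 3-D tensor to "nqhd,nkhd->nhqk" raises a shape error.
  • Scaling factor: Forgetting to divide the scores by the square root of the per-head dimension (head_dim) lets them grow with the dimension, which saturates the softmax and leaves very small, unstable gradients.
  • Masking: Without a mask, attention also assigns weight to padding tokens.
  • Bias in linear layers: Many implementations set bias=False for the query, key, and value projections to keep them simple; this is a design choice rather than a requirement, and PyTorch's nn.MultiheadAttention uses biases by default.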
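To see why the scaling factor matters, the sketch below uses random data (sizes chosen for illustration) to measure the variance of raw and scaled dot products and to compare how peaked the resulting softmax is:

```python
import torch

torch.manual_seed(0)
d_k = 64  # illustrative per-head dimension

# 10,000 random query/key pairs with unit-variance entries
q = torch.randn(10_000, d_k)
k = torch.randn(10_000, d_k)

raw = (q * k).sum(dim=-1)      # unscaled dot products
scaled = raw / d_k ** 0.5      # divided by sqrt(d_k)

print(raw.var().item())        # close to d_k (about 64)
print(scaled.var().item())     # close to 1

# Treat 8 of the scores as one query's attention logits over 8 keys
logits = raw[:8]
peak_raw = torch.softmax(logits, dim=-1).max().item()
peak_scaled = torch.softmax(logits / d_k ** 0.5, dim=-1).max().item()
print(peak_raw, peak_scaled)   # the unscaled softmax is more peaked
```

The variance of a dot product of two vectors with independent unit-variance entries equals d_k, so dividing by sqrt(d_k) brings it back to about 1 regardless of the head size.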
python
import torch
import torch.nn as nn

# Wrong: a single head over the full embedding, and the scores are not scaled
class WrongSelfAttention(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)
        # Missing: splitting into heads and dividing the scores by sqrt(d_k)
        scores = torch.matmul(queries, keys.transpose(-2, -1))
        attention = torch.softmax(scores, dim=-1)
        out = torch.matmul(attention, values)
        return out

# Right: project, split into heads, and scale the scores
class RightSelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        assert embed_size % heads == 0, "embed_size must be divisible by heads"
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        N, seq_length, _ = x.shape
        # Project the full embedding, then split into (N, seq_length, heads, head_dim)
        values = self.values(x).reshape(N, seq_length, self.heads, self.head_dim)
        keys = self.keys(x).reshape(N, seq_length, self.heads, self.head_dim)
        queries = self.queries(x).reshape(N, seq_length, self.heads, self.head_dim)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # Scale by sqrt(head_dim) before the softmax over keys
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, seq_length, self.embed_size)
        out = self.fc_out(out)
        return out
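The pitfall classes do not take a mask. A minimal sketch of a padding mask, assuming the `mask == 0` convention used in `SelfAttention.forward` and a made-up batch where `1` marks a real token and `0` marks padding:

```python
import torch

torch.manual_seed(0)
N, heads, seq_len = 2, 2, 4  # illustrative sizes

# Hypothetical padding pattern: 1 = real token, 0 = padding
pad_mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]])             # (N, key_len)

# Stand-in for the raw scores computed by the einsum: (N, heads, query_len, key_len)
energy = torch.randn(N, heads, seq_len, seq_len)

# Add singleton dims so the mask broadcasts over heads and query positions
mask = pad_mask[:, None, None, :]                   # (N, 1, 1, key_len)
energy = energy.masked_fill(mask == 0, float("-1e20"))
attention = torch.softmax(energy, dim=-1)

# For the second sequence, the padded key columns 2 and 3 get zero weight
print(attention[1, 0])
```

A mask of this shape can be passed as the `mask` argument of `SelfAttention`, since it broadcasts to `(N, heads, query_len, key_len)`.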

Quick Reference

  • Use linear layers to create queries, keys, and values from input.
  • Split embeddings into multiple heads for parallel attention.
  • Compute attention scores with scaled dot-product.
  • Apply softmax to get attention weights.
  • Multiply weights by values and combine heads.
  • Use masking to ignore padding tokens if needed.
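PyTorch also provides these steps as a built-in layer, `nn.MultiheadAttention`. A short usage sketch with illustrative sizes: `batch_first=True` keeps the `(N, seq_len, embed_size)` layout used in this guide, and `key_padding_mask` uses the opposite convention to the `mask` argument of `SelfAttention` (here `True` means ignore).

```python
import torch
import torch.nn as nn

embed_size, heads = 8, 2  # illustrative sizes
mha = nn.MultiheadAttention(embed_dim=embed_size, num_heads=heads, batch_first=True)

x = torch.rand(2, 4, embed_size)                  # (N, seq_len, embed_size)

# True marks keys to ignore: here the last token of the second sequence
key_padding_mask = torch.tensor([[False, False, False, False],
                                 [False, False, False, True]])

# Self-attention: the same tensor is used as query, key, and value
out, weights = mha(x, x, x, key_padding_mask=key_padding_mask)
print(out.shape)      # torch.Size([2, 4, 8])
print(weights.shape)  # torch.Size([2, 4, 4]), averaged over heads by default
```

Pass `average_attn_weights=False` to get per-head weights instead of the average.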

Key Takeaways

Self-attention computes relationships by comparing queries and keys, then weighting values accordingly.
Always reshape inputs to handle multiple heads, and scale attention scores by the square root of the per-head dimension (head_dim).
Use masking to prevent attention on irrelevant tokens like padding.
Linear layers without bias are common for queries, keys, and values, but biases are optional (PyTorch's nn.MultiheadAttention includes them by default).
Test your implementation with dummy inputs to verify output shapes and values.
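One way to test a hand-written implementation is to compare it against PyTorch's reference `torch.nn.functional.scaled_dot_product_attention`. The sketch below assumes PyTorch 2.0 or later (where that function exists) and uses random tensors already split into heads, shaped `(N, heads, seq_len, head_dim)`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, heads, seq_len, head_dim = 2, 2, 4, 4  # illustrative sizes
q = torch.randn(N, heads, seq_len, head_dim)
k = torch.randn(N, heads, seq_len, head_dim)
v = torch.randn(N, heads, seq_len, head_dim)

# Manual scaled dot-product attention
weights = torch.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
manual = weights @ v

# PyTorch reference (default scale is 1 / sqrt(head_dim))
reference = F.scaled_dot_product_attention(q, k, v)

print(manual.shape)                                   # torch.Size([2, 2, 4, 4])
print(torch.allclose(manual, reference, atol=1e-5))   # expected: True
```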