
Self-attention mechanism in PyTorch

Introduction

Self-attention helps a model focus on the important parts of its input when making decisions. It compares every part of the input against every other part and decides which parts matter most for each position. Common applications include:

Understanding relationships between words in a sentence for translation.
Finding important features in an image for recognition.
Analyzing time series data where past and future points influence each other.
Improving chatbot responses by focusing on relevant conversation parts.
Syntax
PyTorch
import torch
import torch.nn.functional as F

def self_attention(query, key, value):
    scores = torch.matmul(query, key.transpose(-2, -1)) / (query.size(-1) ** 0.5)
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, value)
    return output, weights

query, key, and value are tensors representing the input in different ways; in self-attention, all three are derived from the same input sequence.

The scores measure how strongly each position should attend to every other position.
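As a sketch of the shapes involved (assuming a single sequence of 3 items with 4 features each): the scores and weights form a square (seq_len, seq_len) matrix, while the output keeps the input's shape.

```python
import torch
import torch.nn.functional as F

def self_attention(query, key, value):
    scores = torch.matmul(query, key.transpose(-2, -1)) / (query.size(-1) ** 0.5)
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, value)
    return output, weights

x = torch.randn(3, 4)                      # 3 items, 4 features each
output, weights = self_attention(x, x, x)  # same tensor plays all three roles
print(weights.shape)  # torch.Size([3, 3]) - one weight per item pair
print(output.shape)   # torch.Size([3, 4]) - same shape as the input
```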

Examples
Simple example where query, key, and value are single vectors.
PyTorch
query = torch.tensor([[1., 0., 1.]])
key = torch.tensor([[1., 0., 1.]])
value = torch.tensor([[0., 1., 0.]])
output, weights = self_attention(query, key, value)
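With a single query and key, the softmax runs over exactly one score, so the weight is 1 and the output is simply the value vector. A quick check (repeating the function definition so the snippet runs on its own):

```python
import torch
import torch.nn.functional as F

def self_attention(query, key, value):
    scores = torch.matmul(query, key.transpose(-2, -1)) / (query.size(-1) ** 0.5)
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, value)
    return output, weights

query = torch.tensor([[1., 0., 1.]])
key = torch.tensor([[1., 0., 1.]])
value = torch.tensor([[0., 1., 0.]])

output, weights = self_attention(query, key, value)
print(weights)  # tensor([[1.]]) - only one item to attend to
print(output)   # tensor([[0., 1., 0.]]) - identical to value
```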
Batch of 2 sequences, each with 3 items and 4 features.
PyTorch
query = torch.randn(2, 3, 4)  # batch=2, seq_len=3, features=4
key = torch.randn(2, 3, 4)
value = torch.randn(2, 3, 4)
output, weights = self_attention(query, key, value)
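The batched call keeps the input's batch and sequence dimensions in the output, while the weights hold one (seq_len, seq_len) attention matrix per batch entry. A shape sanity check, with the function repeated so the snippet is self-contained:

```python
import torch
import torch.nn.functional as F

def self_attention(query, key, value):
    scores = torch.matmul(query, key.transpose(-2, -1)) / (query.size(-1) ** 0.5)
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, value)
    return output, weights

query = torch.randn(2, 3, 4)  # batch=2, seq_len=3, features=4
key = torch.randn(2, 3, 4)
value = torch.randn(2, 3, 4)

output, weights = self_attention(query, key, value)
print(output.shape)   # torch.Size([2, 3, 4]) - matches the input
print(weights.shape)  # torch.Size([2, 3, 3]) - one 3x3 matrix per sequence
```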
Sample Model

This program computes self-attention on a small example with 3 items in a sequence. It prints the attention weights and the output vectors.

PyTorch
import torch
import torch.nn.functional as F

def self_attention(query, key, value):
    scores = torch.matmul(query, key.transpose(-2, -1)) / (query.size(-1) ** 0.5)
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, value)
    return output, weights

# Example input: batch size 1, sequence length 3, feature size 4
query = torch.tensor([[[1., 0., 1., 0.],
                       [0., 2., 0., 2.],
                       [1., 1., 1., 1.]]])
key = query.clone()
value = torch.tensor([[[0., 1., 0., 1.],
                       [1., 0., 1., 0.],
                       [0., 0., 1., 1.]]])

output, weights = self_attention(query, key, value)

print("Attention weights:", weights)
print("Output:", output)
Important Notes

Self-attention uses the same input for query, key, and value to find relationships within the input.

Dividing by the square root of the feature size keeps the scores from growing with the feature dimension, which would otherwise saturate the softmax and make its gradients vanish.
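A small demonstration of why this helps: for random features, raw dot products have a standard deviation near the square root of the feature size, while the scaled scores stay near 1 (the exact printed numbers vary slightly with the random draw).

```python
import torch

torch.manual_seed(0)
d = 256                         # feature size
q = torch.randn(10000, d)       # 10,000 random query vectors
k = torch.randn(10000, d)       # 10,000 random key vectors

raw = (q * k).sum(dim=-1)       # unscaled dot products
scaled = raw / (d ** 0.5)       # scaled as in the attention formula

print(raw.std())     # roughly sqrt(256) = 16
print(scaled.std())  # roughly 1
```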

Softmax turns scores into probabilities that sum to 1, showing importance.
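This property is easy to verify directly: whatever the raw scores are, each row of softmax output sums to 1.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[1.0, 0.0, 1.0],
                       [0.0, 4.0, 2.0]])  # arbitrary raw scores

weights = F.softmax(scores, dim=-1)  # normalize each row
print(weights.sum(dim=-1))  # each row sums to 1 (up to float rounding)
```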

Summary

Self-attention helps models focus on important parts of input data.

It compares each part to every other part using query, key, and value.

Outputs are weighted sums showing combined important information.