
How to Build a Transformer Model in PyTorch: Simple Guide

To build a Transformer in PyTorch, use the torch.nn.Transformer module, which provides the encoder-decoder stack but not the embeddings, positional encoding, or output layer. Embed your inputs, add positional encodings, pass the source and target sequences through the transformer, then project the output to the size your task needs.
📐

Syntax

The torch.nn.Transformer class creates a transformer model with encoder and decoder layers. Key parts include:

  • d_model: size of the input embeddings; must be divisible by nhead (default 512)
  • nhead: number of attention heads (default 8)
  • num_encoder_layers and num_decoder_layers: number of layers in the encoder and decoder (default 6 each)
  • dim_feedforward: hidden size of the feedforward network inside each layer (default 2048)
  • forward(src, tgt): passes the source sequence through the encoder and the target sequence through the decoder
python
import torch
import torch.nn as nn

transformer = nn.Transformer(
    d_model=512,       # embedding size
    nhead=8,           # attention heads
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048
)

# src and tgt shape: (sequence_length, batch_size, d_model)
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))

output = transformer(src, tgt)  # output shape: (20, 32, 512)
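
To see how forward(src, tgt) routes the two sequences, the sketch below calls the model's encoder and decoder submodules separately and compares the result with a single forward call. The small sizes are arbitrary choices to keep it fast, and eval() is used so that dropout does not make the two results differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small, arbitrary sizes chosen only to keep the sketch fast
transformer = nn.Transformer(
    d_model=32, nhead=4,
    num_encoder_layers=2, num_decoder_layers=2,
    dim_feedforward=64,
)
transformer.eval()  # disable dropout so both paths give the same numbers

src = torch.rand(10, 4, 32)  # (src_len, batch_size, d_model)
tgt = torch.rand(20, 4, 32)  # (tgt_len, batch_size, d_model)

with torch.no_grad():
    memory = transformer.encoder(src)          # encoder sees only the source
    manual = transformer.decoder(tgt, memory)  # decoder sees target + encoder output
    combined = transformer(src, tgt)           # the same two steps in one call

print(memory.shape)   # encoder output keeps the source length
print(manual.shape)   # decoder output keeps the target length
print(torch.allclose(manual, combined, atol=1e-6))
```

The encoder output has the source length, while the final output always has the target length, which is why the result above follows tgt rather than src.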
💻

Example

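This example shows a minimal transformer for sequence-to-sequence prediction with random data. It includes input embedding, positional encoding, and the transformer model. Because the inputs here are continuous feature vectors rather than token IDs, the embedding is an nn.Linear layer, and no attention masks are applied.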

python
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return x

class SimpleTransformer(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_layers, output_dim):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, num_encoder_layers=num_layers, num_decoder_layers=num_layers)
        self.fc_out = nn.Linear(d_model, output_dim)

    def forward(self, src, tgt):
        src = self.embedding(src)  # (seq_len, batch, d_model)
        tgt = self.embedding(tgt)
        src = self.pos_encoder(src)
        tgt = self.pos_encoder(tgt)
        output = self.transformer(src, tgt)
        output = self.fc_out(output)
        return output

# Example usage
seq_len_src, seq_len_tgt, batch_size, input_dim = 10, 20, 32, 16
model = SimpleTransformer(input_dim=input_dim, d_model=64, nhead=8, num_layers=3, output_dim=10)
src = torch.rand(seq_len_src, batch_size, input_dim)
tgt = torch.rand(seq_len_tgt, batch_size, input_dim)
output = model(src, tgt)
print(output.shape)  # Expected: (seq_len_tgt, batch_size, output_dim)
Output
torch.Size([20, 32, 10])
⚠️

Common Pitfalls

Common mistakes when building transformers in PyTorch include:

  • Not matching input shapes: by default (batch_first=False), nn.Transformer expects inputs shaped (sequence_length, batch_size, embedding_dim).
  • Forgetting positional encoding: Transformers need positional info added to embeddings.
  • Using incorrect mask shapes or forgetting masks for causal decoding.
  • Confusing encoder and decoder inputs: source goes to encoder, target to decoder.
python
import torch
import torch.nn as nn

transformer = nn.Transformer(d_model=512, nhead=8)

# Wrong: (batch_size, seq_len, d_model) passed to a model that expects
# (seq_len, batch_size, d_model). No error is raised: the model silently
# treats 32 as the sequence length and 10 as the batch size.
src_wrong = torch.rand((32, 10, 512))  # batch first
output = transformer(src_wrong, src_wrong)
print('Wrong layout output shape:', output.shape)

# Right: transpose to (seq_len, batch_size, d_model)
src_right = src_wrong.transpose(0, 1)
output = transformer(src_right, src_right)
print('Right layout output shape:', output.shape)
Output
Wrong layout output shape: torch.Size([32, 10, 512])
Right layout output shape: torch.Size([10, 32, 512])
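
If your data is already batch-first, an alternative to transposing is the batch_first=True constructor argument, which is available in recent PyTorch releases. The following is a minimal sketch with arbitrary small sizes, not a recommendation for real model dimensions.

```python
import torch
import torch.nn as nn

# batch_first=True makes the model accept (batch_size, seq_len, d_model)
transformer = nn.Transformer(
    d_model=32, nhead=4,
    num_encoder_layers=1, num_decoder_layers=1,
    dim_feedforward=64,
    batch_first=True,
)

src = torch.rand(8, 10, 32)  # (batch_size, src_len, d_model)
tgt = torch.rand(8, 20, 32)  # (batch_size, tgt_len, d_model)

output = transformer(src, tgt)
print(output.shape)  # (batch_size, tgt_len, d_model)
```

With this setting the output is batch-first as well, so downstream layers see the same layout as the inputs.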
📊

Quick Reference

Tips for building transformers in PyTorch:

  • Use nn.Transformer for the full encoder-decoder architecture.
  • Input shape is (sequence_length, batch_size, embedding_dim) by default; pass batch_first=True to use batch-first tensors.
  • Add positional encoding to input embeddings.
  • Use masks for decoder to prevent attending to future tokens.
  • Adjust d_model, nhead, and layers for your task size.
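
The masking tip can be sketched as follows. This example builds a causal target mask by hand with torch.triu and passes it as tgt_mask, then checks that changing the last target position leaves the outputs at earlier positions unchanged. The sizes and variable names are illustrative assumptions, and eval() is used so dropout does not interfere with the comparison.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

transformer = nn.Transformer(
    d_model=32, nhead=4,
    num_encoder_layers=1, num_decoder_layers=1,
    dim_feedforward=64,
)
transformer.eval()  # disable dropout for a deterministic comparison

src_len, tgt_len, batch_size = 6, 5, 2
src = torch.rand(src_len, batch_size, 32)
tgt = torch.rand(tgt_len, batch_size, 32)

# Causal mask: -inf above the diagonal blocks attention to later positions,
# 0 elsewhere leaves attention to current and earlier positions untouched
tgt_mask = torch.triu(torch.full((tgt_len, tgt_len), float('-inf')), diagonal=1)

# Change only the last target position
tgt_changed = tgt.clone()
tgt_changed[-1] = torch.rand(batch_size, 32)

with torch.no_grad():
    out_a = transformer(src, tgt, tgt_mask=tgt_mask)
    out_b = transformer(src, tgt_changed, tgt_mask=tgt_mask)

# Earlier positions cannot see the changed final position
print(torch.allclose(out_a[:-1], out_b[:-1], atol=1e-6))
```

Without tgt_mask, every decoder position attends to the whole target sequence, so the same check would generally fail.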

Key Takeaways

Use torch.nn.Transformer with correct input shapes: (seq_len, batch, d_model) by default.
Always add positional encoding to input embeddings before the transformer.
Provide source sequence to encoder and target sequence to decoder.
Use masks to control attention flow, especially for autoregressive decoding.
Adjust model size parameters to balance performance and speed.
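
To make the size trade-off concrete, the sketch below counts trainable parameters for two configurations. The helper count_parameters and both configurations are hypothetical, chosen only for illustration; parameter count is a rough proxy for memory and compute cost, not a measurement of speed.

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters (hypothetical helper)."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Two illustrative configurations; neither is a recommended setting
small = nn.Transformer(d_model=64, nhead=4, num_encoder_layers=2,
                       num_decoder_layers=2, dim_feedforward=256)
large = nn.Transformer(d_model=128, nhead=8, num_encoder_layers=4,
                       num_decoder_layers=4, dim_feedforward=512)

print('small:', count_parameters(small))
print('large:', count_parameters(large))
```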