How to Build a Transformer Model in PyTorch: Simple Guide
To build a Transformer in PyTorch, use the torch.nn.Transformer module, which provides the full transformer architecture. Define input embeddings, add positional encodings, pass the sequences through the transformer layers, and decode the output for your task.
Syntax
The torch.nn.Transformer class creates a transformer model with encoder and decoder layers. Key parts include:
- d_model: size of the input embeddings
- nhead: number of attention heads
- num_encoder_layers and num_decoder_layers: number of layers in the encoder and decoder
- dim_feedforward: size of the feedforward network inside each layer
- forward(src, tgt): method that passes the source and target sequences through the model
python
import torch
import torch.nn as nn

transformer = nn.Transformer(
    d_model=512,           # embedding size
    nhead=8,               # attention heads
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048
)

# src and tgt shape: (sequence_length, batch_size, d_model)
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
output = transformer(src, tgt)
# output shape: (20, 32, 512)
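By default, nn.Transformer expects sequence-first tensors. If your data pipeline produces batch-first tensors, you can pass batch_first=True (available since PyTorch 1.9) so both inputs and outputs use the (batch_size, sequence_length, d_model) layout instead:

```python
import torch
import torch.nn as nn

# batch_first=True makes the transformer accept (batch, seq, feature) tensors
transformer = nn.Transformer(d_model=512, nhead=8, batch_first=True)

src = torch.rand((32, 10, 512))  # (batch_size, seq_len, d_model)
tgt = torch.rand((32, 20, 512))
output = transformer(src, tgt)
print(output.shape)  # torch.Size([32, 20, 512])
```

The output follows the target sequence length, just as in the sequence-first layout.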
Example
This example shows a minimal transformer for sequence-to-sequence prediction with random data. It includes input embedding, positional encoding, and the transformer model.
python
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return x

class SimpleTransformer(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_layers, output_dim):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead,
                                          num_encoder_layers=num_layers,
                                          num_decoder_layers=num_layers)
        self.fc_out = nn.Linear(d_model, output_dim)

    def forward(self, src, tgt):
        src = self.embedding(src)  # (seq_len, batch, d_model)
        tgt = self.embedding(tgt)
        src = self.pos_encoder(src)
        tgt = self.pos_encoder(tgt)
        output = self.transformer(src, tgt)
        output = self.fc_out(output)
        return output

# Example usage
seq_len_src, seq_len_tgt, batch_size, input_dim = 10, 20, 32, 16
model = SimpleTransformer(input_dim=input_dim, d_model=64, nhead=8,
                          num_layers=3, output_dim=10)
src = torch.rand(seq_len_src, batch_size, input_dim)
tgt = torch.rand(seq_len_tgt, batch_size, input_dim)
output = model(src, tgt)
print(output.shape)  # Expected: (seq_len_tgt, batch_size, output_dim)
Output
torch.Size([20, 32, 10])
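The example above only runs a forward pass. The model trains like any other nn.Module; below is a minimal sketch using a plain nn.Transformer with a linear head on random regression targets (the Adam optimizer, learning rate, and MSE loss here are illustrative choices, not part of the original example):

```python
import torch
import torch.nn as nn

# Minimal training sketch: transformer + linear head on dummy regression data
d_model, nhead, output_dim = 64, 8, 10
transformer = nn.Transformer(d_model=d_model, nhead=nhead,
                             num_encoder_layers=2, num_decoder_layers=2)
head = nn.Linear(d_model, output_dim)

params = list(transformer.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)
criterion = nn.MSELoss()

src = torch.rand(10, 32, d_model)        # (seq_len, batch, d_model)
tgt = torch.rand(20, 32, d_model)
labels = torch.rand(20, 32, output_dim)  # dummy regression targets

for step in range(3):                    # a few illustrative steps
    optimizer.zero_grad()
    out = head(transformer(src, tgt))    # (20, 32, output_dim)
    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss {loss.item():.4f}")
```

In a real task the random tensors would be replaced by batches from a DataLoader, and the MSE loss by one appropriate for the task (e.g. cross-entropy for token prediction).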
Common Pitfalls
Common mistakes when building transformers in PyTorch include:
- Not matching input shapes: nn.Transformer expects inputs shaped (sequence_length, batch_size, embedding_dim) by default (or batch-first shapes when constructed with batch_first=True).
- Forgetting positional encoding: transformers need positional information added to the embeddings.
- Using incorrect mask shapes or forgetting masks for causal decoding.
- Confusing encoder and decoder inputs: source goes to encoder, target to decoder.
python
import torch
import torch.nn as nn

transformer = nn.Transformer(d_model=512, nhead=8)

# Wrong: batch-first tensor (batch_size=32, seq_len=10, d_model=512).
# With the default batch_first=False, no error is raised -- the model
# silently treats dim 0 as the sequence and dim 1 as the batch.
src_wrong = torch.rand((32, 10, 512))
output = transformer(src_wrong, src_wrong)
print('Misinterpreted shape:', output.shape)

# Right: transpose to (seq_len, batch_size, d_model)
src_right = src_wrong.transpose(0, 1)
output = transformer(src_right, src_right)
print('Output shape:', output.shape)
Output
Misinterpreted shape: torch.Size([32, 10, 512])
Output shape: torch.Size([10, 32, 512])
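For the causal-masking pitfall, PyTorch provides nn.Transformer.generate_square_subsequent_mask, which builds the upper-triangular mask of -inf values that prevents each position from attending to later positions. A minimal sketch:

```python
import torch
import torch.nn as nn

transformer = nn.Transformer(d_model=512, nhead=8)

src = torch.rand(10, 32, 512)  # (seq_len, batch, d_model)
tgt = torch.rand(20, 32, 512)

# Causal mask: position i may attend only to positions <= i.
# -inf entries are blocked by attention; 0.0 means "allowed".
tgt_mask = transformer.generate_square_subsequent_mask(tgt.size(0))

output = transformer(src, tgt, tgt_mask=tgt_mask)
print(output.shape)  # torch.Size([20, 32, 512])
```

The mask must be square with side length equal to the target sequence length, which is exactly the shape check that fails when masks and inputs disagree.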
Quick Reference
Tips for building transformers in PyTorch:
- Use nn.Transformer for the full transformer architecture.
- Input shape must be (sequence_length, batch_size, embedding_dim).
- Add positional encoding to input embeddings.
- Use masks for decoder to prevent attending to future tokens.
- Adjust d_model, nhead, and the number of layers for your task size; d_model must be divisible by nhead.
Key Takeaways
- Use torch.nn.Transformer with correct input shapes (seq_len, batch, d_model).
- Always add positional encoding to input embeddings before the transformer.
- Provide the source sequence to the encoder and the target sequence to the decoder.
- Use masks to control attention flow, especially for autoregressive decoding.
- Adjust model size parameters to balance performance and speed.