
How to Use nn.Transformer in PyTorch: Syntax and Example

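Use nn.Transformer in PyTorch by creating an instance with parameters such as d_model and nhead, then passing source and target sequences to its forward method. By default it expects inputs shaped (sequence_length, batch_size, d_model), or (batch_size, sequence_length, d_model) if you construct it with batch_first=True. It returns a tensor with the same shape as the target sequence, which you can use for tasks like translation and other sequence-to-sequence modeling.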

Syntax

The nn.Transformer class is initialized with these key parameters (defaults in parentheses):

  • d_model: the number of expected features in the input, i.e. the embedding size (512).
  • nhead: the number of attention heads (8).
  • num_encoder_layers and num_decoder_layers: the number of layers in the encoder and decoder stacks (6 each).
  • dim_feedforward: the hidden size of the feedforward network inside each encoder and decoder layer (2048).
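Each attention head operates on d_model // nhead features, so d_model must be divisible by nhead. The quick check below is a sketch built around a hypothetical helper named builds, using a single encoder and decoder layer to keep construction cheap; it tries one valid and one invalid pair:

```python
import torch.nn as nn


def builds(d_model: int, nhead: int) -> bool:
    """Hypothetical helper: True if nn.Transformer accepts this d_model/nhead pair."""
    try:
        # One layer each keeps construction fast; every attention layer checks the head split.
        nn.Transformer(d_model=d_model, nhead=nhead,
                       num_encoder_layers=1, num_decoder_layers=1)
        return True
    except (AssertionError, ValueError):
        # nn.MultiheadAttention rejects an embedding size that does not split evenly across heads.
        return False


print(builds(512, 8))  # 512 / 8 = 64 features per head
print(builds(512, 7))  # 512 is not divisible by 7
```

The first call prints True and the second prints False, because 512 features cannot be split evenly across 7 heads.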

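The forward method takes a src tensor of shape (source_length, batch_size, d_model) and a tgt tensor of shape (target_length, batch_size, d_model). The two lengths may differ; the output has the same shape as tgt.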

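python
import torch
import torch.nn as nn

transformer = nn.Transformer(
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048
)

# Placeholder inputs shaped (seq_len, batch_size, d_model)
src = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)

output = transformer(src, tgt)  # shape: (20, 32, 512), same as tgt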
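If your data is already batch-first, you can keep it that way: nn.Transformer accepts a batch_first=True constructor argument, after which src, tgt, and the output are all shaped (batch_size, sequence_length, d_model). A minimal sketch, with small sizes chosen arbitrarily:

```python
import torch
import torch.nn as nn

batch_size, src_len, tgt_len, d_model = 4, 10, 7, 64

# batch_first=True switches the expected layout to (batch, seq, feature).
transformer = nn.Transformer(d_model=d_model, nhead=4, num_encoder_layers=2,
                             num_decoder_layers=2, batch_first=True)

src = torch.rand(batch_size, src_len, d_model)
tgt = torch.rand(batch_size, tgt_len, d_model)

out = transformer(src, tgt)
print(out.shape)  # torch.Size([4, 7, 64]): batch first, then target length
```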

Example

This example shows how to create a simple transformer, generate random input data, and run a forward pass to get the output.

python
import torch
import torch.nn as nn

# Parameters
seq_len_src = 10
seq_len_tgt = 20
batch_size = 32
d_model = 512
nhead = 8

# Create transformer
transformer = nn.Transformer(d_model=d_model, nhead=nhead)

# Random source and target sequences (seq_len, batch, feature)
src = torch.rand(seq_len_src, batch_size, d_model)
tgt = torch.rand(seq_len_tgt, batch_size, d_model)

# Forward pass
output = transformer(src, tgt)

print(f"Output shape: {output.shape}")
Output
Output shape: torch.Size([20, 32, 512])
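Random tensors stand in for real data in this example. With token IDs, you would typically map them to d_model-sized vectors with nn.Embedding first and project the decoder output back to vocabulary scores with nn.Linear. The sketch below assumes a hypothetical vocabulary size of 1000 and omits positional encodings, which nn.Transformer does not add on its own:

```python
import torch
import torch.nn as nn

vocab_size = 1000  # hypothetical vocabulary size
d_model = 512
src_len, tgt_len, batch_size = 10, 20, 32

embed = nn.Embedding(vocab_size, d_model)       # token IDs -> d_model features
transformer = nn.Transformer(d_model=d_model, nhead=8)
to_logits = nn.Linear(d_model, vocab_size)      # d_model features -> vocabulary scores

# Integer token IDs in the default (seq_len, batch) layout.
src_ids = torch.randint(0, vocab_size, (src_len, batch_size))
tgt_ids = torch.randint(0, vocab_size, (tgt_len, batch_size))

# Positional encodings are omitted for brevity; a real model would add them to the embeddings.
hidden = transformer(embed(src_ids), embed(tgt_ids))  # (tgt_len, batch, d_model)
logits = to_logits(hidden)                            # (tgt_len, batch, vocab_size)
print(logits.shape)
```

Sharing one embedding table between source and target is a simplification; translation models often use separate vocabularies.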

Common Pitfalls

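  • Input shape mismatch: By default the transformer expects inputs shaped (sequence_length, batch_size, feature_size), not (batch_size, sequence_length, feature_size). Either transpose your tensors or construct the model with batch_first=True.
  • Missing masks: When the decoder must not see future tokens, as in autoregressive generation or teacher-forced training, pass a causal tgt_mask. Padded batches also need key padding masks so that padding positions are ignored.
  • Embedding size mismatch: The last dimension of src and tgt must equal d_model; otherwise forward raises a RuntimeError.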
python
import torch
import torch.nn as nn

transformer = nn.Transformer(d_model=512, nhead=8)

# Wrong input shape (batch first) - causes an error or silently wrong behavior
try:
    src_wrong = torch.rand(32, 10, 512)  # batch, seq, feature
    tgt_wrong = torch.rand(32, 20, 512)
    # Read as seq_len=32 with batch sizes 10 and 20, which do not match.
    # If both sequence lengths were equal, no error would be raised and the
    # model would silently attend across the batch instead of the sequence.
    output_wrong = transformer(src_wrong, tgt_wrong)
except RuntimeError as e:
    print(f"Error with wrong input shape: {e}")

# Correct input shape: (seq, batch, feature)
src_correct = torch.rand(10, 32, 512)
tgt_correct = torch.rand(20, 32, 512)
output_correct = transformer(src_correct, tgt_correct)
print(f"Output shape with correct input: {output_correct.shape}")
Output
Error with wrong input shape: the batch number of src and tgt must be equal
Output shape with correct input: torch.Size([20, 32, 512])
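For the missing-masks pitfall, nn.Transformer provides generate_square_subsequent_mask, which builds a causal mask for the decoder. Padding is handled separately by the key padding mask arguments, which take (batch_size, sequence_length) booleans where True marks a position to ignore. A minimal sketch, assuming made-up sizes and a hypothetical second source sequence with only 4 real tokens:

```python
import torch
import torch.nn as nn

src_len, tgt_len, batch_size, d_model = 6, 5, 2, 32
transformer = nn.Transformer(d_model=d_model, nhead=4, num_encoder_layers=1,
                             num_decoder_layers=1, dropout=0.0)

src = torch.rand(src_len, batch_size, d_model)
tgt = torch.rand(tgt_len, batch_size, d_model)

# Causal mask: -inf above the diagonal, so target position i attends only to positions <= i.
tgt_mask = transformer.generate_square_subsequent_mask(tgt_len)

# Padding mask: True means "ignore". Hypothetically, sequence 1 has 4 real tokens.
src_padding = torch.zeros(batch_size, src_len, dtype=torch.bool)
src_padding[1, 4:] = True

out = transformer(
    src, tgt,
    tgt_mask=tgt_mask,
    src_key_padding_mask=src_padding,     # encoder self-attention skips padding
    memory_key_padding_mask=src_padding,  # decoder cross-attention skips it too
)
print(out.shape)  # torch.Size([5, 2, 32])
```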

Quick Reference

Remember these tips when using nn.Transformer:

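  • Input shape must be (sequence_length, batch_size, d_model), unless you pass batch_first=True.
  • Use src_mask, tgt_mask, and memory_mask to restrict which positions can attend to which, and src_key_padding_mask, tgt_key_padding_mask, and memory_key_padding_mask to ignore padding.
  • Set d_model to match your embedding size.
  • Use nn.TransformerEncoder for encoder-only models. nn.TransformerDecoder always takes an encoder memory for cross-attention, so decoder-only (GPT-style) models are usually built from nn.TransformerEncoder with a causal mask.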
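An encoder-only stack can be sketched with nn.TransformerEncoder wrapping an nn.TransformerEncoderLayer. Running the same stack with and without a causal mask shows the difference between bidirectional (BERT-style) and left-to-right (GPT-style) attention. Layer sizes here are arbitrary, and the mask is built by hand with torch.triu:

```python
import torch
import torch.nn as nn

seq_len, batch_size, d_model = 12, 3, 64

layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128)
encoder = nn.TransformerEncoder(layer, num_layers=2)  # stacks 2 copies of `layer`

x = torch.rand(seq_len, batch_size, d_model)  # default (seq, batch, feature) layout

# No mask: every position attends to the whole sequence.
bidirectional = encoder(x)

# Causal mask: -inf above the diagonal, so each position only sees itself and earlier ones.
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
causal = encoder(x, mask=causal_mask)

print(bidirectional.shape, causal.shape)
```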

Key Takeaways

  • Initialize nn.Transformer with d_model equal to your embedding size and an nhead that divides d_model evenly.
  • Inputs must be shaped (sequence_length, batch_size, d_model), or (batch_size, sequence_length, d_model) with batch_first=True.
  • Use masks to prevent attention to future tokens when needed.
  • nn.Transformer combines an encoder and a decoder; use nn.TransformerEncoder or nn.TransformerDecoder on their own for simpler models.
  • Check input shapes carefully to avoid runtime errors.