How to Use nn.Transformer in PyTorch: Syntax and Example
Use nn.Transformer in PyTorch by creating an instance with parameters such as d_model and nhead, then pass source and target sequences to its forward method. By default it expects inputs shaped as (sequence_length, batch_size, feature_size) and outputs transformed sequences for tasks like translation or sequence modeling.
Syntax
The nn.Transformer class in PyTorch is initialized with key parameters:
- d_model: the number of expected features in the input (embedding size).
- nhead: the number of attention heads.
- num_encoder_layers and num_decoder_layers: the number of layers in the encoder and decoder.
- dim_feedforward: the size of the feedforward network inside the transformer.
The forward method takes src and tgt tensors with shape (sequence_length, batch_size, d_model).
```python
import torch.nn as nn

transformer = nn.Transformer(
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048
)

# src and tgt must be (seq_len, batch, d_model) tensors
output = transformer(src, tgt)
```
Example
This example shows how to create a simple transformer, generate random input data, and run a forward pass to get the output.
```python
import torch
import torch.nn as nn

# Parameters
seq_len_src = 10
seq_len_tgt = 20
batch_size = 32
d_model = 512
nhead = 8

# Create transformer
transformer = nn.Transformer(d_model=d_model, nhead=nhead)

# Random source and target sequences (seq_len, batch, feature)
src = torch.rand(seq_len_src, batch_size, d_model)
tgt = torch.rand(seq_len_tgt, batch_size, d_model)

# Forward pass
output = transformer(src, tgt)
print(f"Output shape: {output.shape}")
```
Output
Output shape: torch.Size([20, 32, 512])
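In practice, the random tensors above stand in for embedded token sequences. A minimal sketch of how token indices are mapped into d_model-dimensional vectors with nn.Embedding (the vocabulary size and variable names here are illustrative, not part of the example above):

```python
import torch
import torch.nn as nn

vocab_size = 1000  # illustrative vocabulary size
d_model = 512

# Token indices with shape (seq_len, batch_size)
src_tokens = torch.randint(0, vocab_size, (10, 32))

# Map indices to the d_model-dimensional vectors the transformer expects
embedding = nn.Embedding(vocab_size, d_model)
src = embedding(src_tokens)
print(src.shape)  # torch.Size([10, 32, 512])
```

A full model would also add positional encodings before feeding src to the transformer, since attention alone is order-agnostic.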
Common Pitfalls
- Input shape mismatch: The transformer expects inputs shaped as (sequence_length, batch_size, feature_size), not (batch_size, sequence_length, feature_size).
- Missing masks: For tasks like language modeling, you often need to provide masks to prevent attention to future tokens.
- Embedding size mismatch: the input feature size must match d_model.
```python
import torch
import torch.nn as nn

transformer = nn.Transformer(d_model=512, nhead=8)

# Wrong input shape (batch first) - will cause an error or wrong behavior
try:
    src_wrong = torch.rand(32, 10, 512)  # (batch, seq, feature)
    tgt_wrong = torch.rand(32, 20, 512)
    output_wrong = transformer(src_wrong, tgt_wrong)
except Exception as e:
    print(f"Error with wrong input shape: {e}")

# Correct input shape: (seq, batch, feature)
src_correct = torch.rand(10, 32, 512)
tgt_correct = torch.rand(20, 32, 512)
output_correct = transformer(src_correct, tgt_correct)
print(f"Output shape with correct input: {output_correct.shape}")
```
Output
Error with wrong input shape: Expected src_mask to have shape [10, 10], but got [32, 10]
Output shape with correct input: torch.Size([20, 32, 512])
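The missing-mask pitfall can be addressed with the built-in causal mask helper, which blocks attention to future target positions. A short sketch:

```python
import torch
import torch.nn as nn

transformer = nn.Transformer(d_model=512, nhead=8)

seq_len_tgt = 20
# Causal mask: -inf above the diagonal prevents attending to future tokens
tgt_mask = transformer.generate_square_subsequent_mask(seq_len_tgt)

src = torch.rand(10, 32, 512)
tgt = torch.rand(seq_len_tgt, 32, 512)
output = transformer(src, tgt, tgt_mask=tgt_mask)
print(output.shape)  # torch.Size([20, 32, 512])
```

The mask must be square with side length equal to the target sequence length, which is exactly what the earlier error message about src_mask shape was complaining about.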
Quick Reference
Remember these tips when using nn.Transformer:
- Input shape must be (sequence_length, batch_size, d_model).
- Use src_mask and tgt_mask to control attention flow.
- Set d_model to match your embedding size.
- Use nn.TransformerEncoder or nn.TransformerDecoder for encoder-only or decoder-only models.
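For an encoder-only model, nn.TransformerEncoder stacks nn.TransformerEncoderLayer modules without a decoder. A minimal sketch:

```python
import torch
import torch.nn as nn

d_model, nhead = 512, 8

# Build a 6-layer encoder-only stack
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

src = torch.rand(10, 32, d_model)  # (seq_len, batch, d_model)
memory = encoder(src)
print(memory.shape)  # torch.Size([10, 32, 512])
```

The output keeps the input shape, so encoder-only stacks are a common choice for classification or feature-extraction heads.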
Key Takeaways
- Initialize nn.Transformer with d_model and nhead matching your data embedding size and attention heads.
- Inputs must be shaped (sequence_length, batch_size, d_model) for correct processing.
- Use masks to prevent attention to future tokens when needed.
- nn.Transformer combines encoder and decoder; use the encoder or decoder classes for simpler models.
- Check input shapes carefully to avoid runtime errors.