PyTorch · How-To · Beginner · 3 min read

How to Use nn.TransformerEncoder in PyTorch: Syntax and Example

Use nn.TransformerEncoder by first creating an nn.TransformerEncoderLayer, which defines one encoder block, then stacking copies of it with nn.TransformerEncoder to build the full encoder. Pass an input tensor of shape (sequence_length, batch_size, embedding_dim), the default layout when batch_first=False, through the encoder to get the transformed output.
📐

Syntax

The nn.TransformerEncoder requires a TransformerEncoderLayer, which defines the architecture of one encoder block, and a num_layers count specifying how many copies of that block to stack. By default (batch_first=False), the input tensor shape must be (sequence_length, batch_size, embedding_dim).

Key parts:

  • TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout): Defines one encoder layer with model dimension, number of attention heads, feedforward size, and dropout. Note that d_model must be divisible by nhead.
  • TransformerEncoder(encoder_layer, num_layers): Stacks multiple encoder layers.
  • Input shape: (seq_len, batch_size, embedding_dim).
python
import torch
import torch.nn as nn

# Define one encoder layer
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)

# Stack 6 such layers to build the encoder
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

# Input tensor: (sequence_length, batch_size, embedding_dim)
x = torch.rand(10, 32, 512)  # 10 tokens, batch size 32, embedding dim 512

# Forward pass
output = transformer_encoder(x)
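If you prefer batch-first tensors, recent PyTorch versions (1.9+) accept batch_first=True on the layer, so inputs can be passed as (batch_size, sequence_length, embedding_dim) without permuting:

```python
import torch
import torch.nn as nn

# batch_first=True makes the encoder accept (batch_size, seq_len, embedding_dim)
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

x = torch.rand(32, 10, 512)  # batch size 32, 10 tokens, embedding dim 512
output = transformer_encoder(x)
print(output.shape)  # same batch-first layout on the way out
```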
💻

Example

This example creates a TransformerEncoder with 2 layers and runs a random input tensor through it. It prints the output shape to confirm the transformation.

python
import torch
import torch.nn as nn

# Create one encoder layer
encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=8, dim_feedforward=256, dropout=0.1)

# Stack 2 layers
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

# Random input: sequence length 5, batch size 3, embedding dim 64
x = torch.rand(5, 3, 64)

# Pass input through encoder
output = transformer_encoder(x)

# Print output shape
print('Output shape:', output.shape)
Output
Output shape: torch.Size([5, 3, 64])
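In practice, batches usually contain sequences of different lengths that are padded to a common length, and you will want attention to ignore the padding. The encoder's forward pass accepts a src_key_padding_mask for this; a short sketch (the mask pattern here is illustrative):

```python
import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

# Input: sequence length 5, batch size 3, embedding dim 64
x = torch.rand(5, 3, 64)

# Boolean mask of shape (batch_size, seq_len); True marks padded positions
padding_mask = torch.tensor([
    [False, False, False, False, True],   # sequence 0: last token is padding
    [False, False, False, True,  True],   # sequence 1: last two tokens are padding
    [False, False, False, False, False],  # sequence 2: no padding
])

# Attention skips the masked positions; the output shape is unchanged
output = transformer_encoder(x, src_key_padding_mask=padding_mask)
print('Output shape:', output.shape)
```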
⚠️

Common Pitfalls

  • Wrong input shape: By default the input must be (sequence_length, batch_size, embedding_dim), not (batch_size, sequence_length, embedding_dim). This mistake usually does not raise an error; the encoder silently treats the batch dimension as the sequence, so attention mixes the wrong tokens. Set batch_first=True on the layer if you want batch-first inputs.
  • Mismatch in dimensions: The d_model in TransformerEncoderLayer must match the embedding dimension of your input, and d_model must be divisible by nhead.
  • Forgetting to stack layers: Creating only one TransformerEncoderLayer does not build the full encoder; you must wrap it with TransformerEncoder and specify num_layers.
python
import torch
import torch.nn as nn

# Wrong input shape example
encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)

# Incorrect input shape (batch_size, seq_len, embedding_dim)
x_wrong = torch.rand(3, 5, 64)

# No exception is raised: the encoder silently treats dimension 0 as the
# sequence and dimension 1 as the batch, so attention mixes the wrong tokens
output_wrong = transformer_encoder(x_wrong)
print('Output shape with wrong input:', output_wrong.shape)

# Correct input shape
x_correct = x_wrong.permute(1, 0, 2)  # (seq_len, batch_size, embedding_dim)
output_correct = transformer_encoder(x_correct)
print('Output shape with correct input:', output_correct.shape)
Output
Output shape with wrong input: torch.Size([3, 5, 64])
Output shape with correct input: torch.Size([5, 3, 64])
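The dimension mismatch from the second pitfall, by contrast, does fail loudly, because the linear projections inside self-attention expect exactly d_model features. A quick sketch (the tensor sizes are illustrative):

```python
import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)

# Embedding dim 32 does not match d_model=64, so the forward pass
# fails with a shape error inside self-attention
x_bad = torch.rand(5, 3, 32)

try:
    transformer_encoder(x_bad)
    print('no error')
except Exception as e:
    print('Raised:', type(e).__name__)
```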
📊

Quick Reference

  • d_model: Embedding dimension of the input and the model.
  • nhead: Number of attention heads in multi-head attention (must divide d_model).
  • dim_feedforward: Dimension of the feedforward network inside each encoder layer.
  • dropout: Dropout rate for regularization.
  • num_layers: Number of encoder layers stacked in TransformerEncoder.
  • Input shape: (sequence_length, batch_size, embedding_dim) by default.

Key Takeaways

  • Always create a TransformerEncoderLayer first, then stack it with TransformerEncoder, specifying num_layers.
  • Input tensor shape must be (sequence_length, batch_size, embedding_dim), not batch first, unless batch_first=True is set.
  • d_model must match the embedding dimension of your input data.
  • TransformerEncoder stacks multiple encoder layers to build a deep encoder.
  • Check input shapes carefully: a batch-first tensor passed to a sequence-first encoder fails silently.
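To see how the encoder typically sits inside a model, here is a minimal end-to-end sketch: token ids go through an embedding, fixed sinusoidal positional encodings are added, and the result is passed to the encoder. The vocabulary size, dimensions, and positional-encoding scheme are illustrative choices, not requirements of the API:

```python
import math
import torch
import torch.nn as nn

class TinyEncoderModel(nn.Module):
    def __init__(self, vocab_size=1000, d_model=64, nhead=8, num_layers=2, max_len=50):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # Fixed sinusoidal positional encodings, stored as (max_len, 1, d_model)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(1))
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, tokens):
        # tokens: (seq_len, batch_size) of token ids
        x = self.embed(tokens) + self.pe[:tokens.size(0)]
        return self.encoder(x)  # (seq_len, batch_size, d_model)

model = TinyEncoderModel()
tokens = torch.randint(0, 1000, (5, 3))  # 5 tokens, batch of 3
out = model(tokens)
print('Output shape:', out.shape)
```

A real model would add dropout after the positional encoding and a task-specific head (e.g. a linear classifier) on top of the encoder output.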