
Transformer decoder in PyTorch

Introduction

The Transformer decoder helps a model generate sequences, like sentences, by attending both to what it has generated so far and to the encoder's representation of the input.

When building a language translator that converts text from one language to another.
When creating a chatbot that replies to messages step-by-step.
When generating text, like writing stories or summaries.
When predicting the next word in a sentence based on previous words.
Syntax
PyTorch
torch.nn.TransformerDecoder(decoder_layer, num_layers, norm=None)

decoder_layer is a single layer that defines how the decoder works.

num_layers is how many decoder layers you stack to make the full decoder.

norm is an optional normalization module (for example, nn.LayerNorm) applied to the output of the final decoder layer.

Examples
This creates a decoder with 6 layers, each with 512 features (d_model) and 8 attention heads.
PyTorch
decoder_layer = torch.nn.TransformerDecoderLayer(d_model=512, nhead=8)
decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers=6)
Here, tgt is the target sequence input, and memory is the encoded source sequence from the encoder.
PyTorch
output = decoder(tgt, memory)
Sample Model

This code builds a Transformer decoder with 2 layers. It creates random input sequences and passes them through the decoder. The output shape shows the sequence length, batch size, and feature size. The first output vector is printed rounded to 4 decimals.

PyTorch
import torch
from torch import nn

# Set seed for reproducibility
torch.manual_seed(0)

# Define parameters
d_model = 32
nhead = 4
num_layers = 2
seq_len = 5
batch_size = 1

# Create a single decoder layer
decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
# Stack decoder layers
decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

# Create dummy target sequence (tgt) and memory from encoder
# Shape: (sequence length, batch size, feature size)
tgt = torch.rand(seq_len, batch_size, d_model)
memory = torch.rand(seq_len, batch_size, d_model)

# Pass through decoder
output = decoder(tgt, memory)

# Print output shape and first vector
print(f"Output shape: {output.shape}")
print(f"First output vector:\n{output[0,0,:].round(decimals=4)}")
Important Notes

The decoder expects inputs shaped as (sequence length, batch size, feature size).
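If you prefer tensors shaped as (batch size, sequence length, feature size), recent PyTorch versions let you pass batch_first=True when creating the layer. A minimal sketch:

```python
import torch
from torch import nn

# batch_first=True makes the decoder expect (batch, sequence, feature) tensors
layer = nn.TransformerDecoderLayer(d_model=32, nhead=4, batch_first=True)
decoder = nn.TransformerDecoder(layer, num_layers=2)

tgt = torch.rand(1, 5, 32)     # (batch, seq, feature)
memory = torch.rand(1, 5, 32)

out = decoder(tgt, memory)
print(out.shape)  # torch.Size([1, 5, 32])
```

The output keeps the same batch-first layout as the inputs.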

The memory input comes from the encoder output and helps the decoder focus on the source sequence.

Masking is often used in real tasks to prevent the decoder from seeing future tokens, but is omitted here for simplicity.
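As a sketch of how such a mask looks in practice, a causal mask can be built as an upper-triangular matrix of -inf values and passed as tgt_mask, so each position can only attend to itself and earlier positions:

```python
import torch
from torch import nn

seq_len, batch_size, d_model = 5, 1, 32
layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=4)
decoder = nn.TransformerDecoder(layer, num_layers=2)

tgt = torch.rand(seq_len, batch_size, d_model)
memory = torch.rand(seq_len, batch_size, d_model)

# -inf above the diagonal blocks attention to future positions;
# the mask is added to the attention scores before softmax
tgt_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)

output = decoder(tgt, memory, tgt_mask=tgt_mask)
print(output.shape)  # torch.Size([5, 1, 32])
```

The output shape is unchanged; only which positions each token may attend to differs.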

Summary

The Transformer decoder generates output sequences by looking at previous outputs and encoder information.

It is built by stacking decoder layers, each combining self-attention over the target sequence with cross-attention over the encoder output.

Use it when you want to create models that generate or predict sequences step-by-step.
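To make the step-by-step idea concrete, here is a hedged sketch of greedy generation. The embedding, vocabulary size, and output head are illustrative assumptions for this sketch, not part of nn.TransformerDecoder itself:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions for this sketch)
vocab_size, d_model, nhead, num_layers = 10, 32, 4, 2

embed = nn.Embedding(vocab_size, d_model)           # token ids -> vectors
layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
decoder = nn.TransformerDecoder(layer, num_layers=num_layers)
head = nn.Linear(d_model, vocab_size)               # vectors -> vocabulary logits

memory = torch.rand(5, 1, d_model)                  # stand-in for encoder output
tokens = torch.tensor([[0]])                        # (seq, batch): start token id 0

# Greedy decoding: repeatedly feed everything generated so far
with torch.no_grad():
    for _ in range(4):
        out = decoder(embed(tokens), memory)        # (seq, 1, d_model)
        next_id = head(out[-1]).argmax(dim=-1)      # most likely next token
        tokens = torch.cat([tokens, next_id.unsqueeze(0)], dim=0)

print(tokens.squeeze(1).tolist())  # 5 token ids, starting with the start token 0
```

Each loop iteration appends one token, which is exactly the "step-by-step" generation the summary describes.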