How to Handle Packed Sequences in PyTorch Correctly
Use torch.nn.utils.rnn.pack_padded_sequence to convert padded sequences into a PackedSequence before feeding them to an RNN. Then use torch.nn.utils.rnn.pad_packed_sequence to convert the outputs back to padded tensors. Always keep track of each sequence's true length. With the default enforce_sorted=True, the batch must also be sorted by length in descending order before packing. Passing enforce_sorted=False lets PyTorch do that sorting for you.
Why This Happens
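When a batch contains sequences of different lengths, the shorter ones are padded to a common length. Packing tells the RNN where each sequence really ends, so it skips the padding positions.

Two things can go wrong:

If you feed padded sequences directly without packing, the RNN treats the padding as real input. No error is raised, but the outputs after each sequence's end and the final hidden states of the shorter sequences are silently wrong.

If you pack with the default enforce_sorted=True but forget to sort by length, pack_padded_sequence raises a RuntimeError.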
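```python
import torch

# Sample padded sequences (batch_size=2, max_len=4); 0 is the padding value
sequences = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
lengths = torch.tensor([3, 2])  # true lengths, never used below (that is the bug)

# Incorrect: feeding padded sequences directly to the RNN.
# This runs without error, but the RNN also steps over the padding zeros,
# so the final hidden state of each shorter sequence is computed from padding too.
rnn = torch.nn.RNN(input_size=1, hidden_size=2, batch_first=True)
output, hidden = rnn(sequences.float().unsqueeze(-1))
```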
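To make the problem concrete, the sketch below runs the same toy batch through one RNN twice: once padded and once packed. The weights are untrained, and the seed is fixed only so the run is repeatable.

For each sequence, the state that actually summarizes its real content is the RNN output at step length - 1. The packed run's final hidden state matches that state. The padded run's final hidden state instead keeps updating over the padding zeros, so it generally does not match.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)  # fixed seed so the run is repeatable

# (batch=2, max_len=4, features=1); lengths are already in descending order
sequences = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]]).float().unsqueeze(-1)
lengths = torch.tensor([3, 2])
rnn = torch.nn.RNN(input_size=1, hidden_size=2, batch_first=True)

# Unpacked: every row runs all 4 steps, padding included
padded_out, padded_hidden = rnn(sequences)

# Packed: each row stops at its own length
packed = pack_padded_sequence(sequences, lengths, batch_first=True)
_, packed_hidden = rnn(packed)

# Output at the last real step of each row: shape (batch, hidden_size)
last_real = padded_out[torch.arange(len(lengths)), lengths - 1]

print("packed hidden matches last real step:",
      torch.allclose(packed_hidden[0], last_real, atol=1e-5))
print("padded hidden matches last real step:",
      torch.allclose(padded_hidden[0], last_real, atol=1e-5))
```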
The Fix
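First, sort the sequences by length in descending order, which the default enforce_sorted=True requires, and remember the permutation. Then use pack_padded_sequence to pack the padded sequences before passing them to the RNN. After the RNN, use pad_packed_sequence to get padded outputs back. Finally, undo the sort so the rows line up with the original batch order.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Sample padded sequences and their true lengths
sequences = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
lengths = torch.tensor([3, 2])

# Sort by length, descending (required by the default enforce_sorted=True)
lengths, perm_idx = lengths.sort(descending=True)
sequences = sequences[perm_idx]

# Pack sequences; lengths must be a CPU tensor (or list) of ints
packed = pack_padded_sequence(sequences.float().unsqueeze(-1), lengths, batch_first=True)

# Define RNN
rnn = torch.nn.RNN(input_size=1, hidden_size=2, batch_first=True)

# Forward pass: padding positions are skipped
packed_output, hidden = rnn(packed)

# Unpack to (batch, longest_length, hidden_size); here longest_length is 3, not 4,
# unless total_length is passed to pad_packed_sequence
output, out_lengths = pad_packed_sequence(packed_output, batch_first=True)

# Undo the sort so rows match the original batch order again
_, unperm_idx = perm_idx.sort()
output = output[unperm_idx]
hidden = hidden[:, unperm_idx]
print(output)
```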
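If you do not need ONNX export, a simpler alternative is to pass enforce_sorted=False. pack_padded_sequence then sorts internally and records the permutation in the PackedSequence. As a result, both pad_packed_sequence and the RNN's returned hidden state come back in the original batch order, with no manual sort or unsort.

The sketch below uses a deliberately unsorted toy batch with untrained weights. The values only illustrate shapes and ordering.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Deliberately unsorted batch: the shorter sequence comes first
sequences = torch.tensor([[4, 5, 0, 0], [1, 2, 3, 0]]).float().unsqueeze(-1)
lengths = torch.tensor([2, 3])

rnn = torch.nn.RNN(input_size=1, hidden_size=2, batch_first=True)

# enforce_sorted=False: PyTorch sorts internally and remembers the permutation
packed = pack_padded_sequence(sequences, lengths, batch_first=True, enforce_sorted=False)
packed_output, hidden = rnn(packed)

# Output rows, returned lengths and hidden state are all in the original order
output, out_lengths = pad_packed_sequence(packed_output, batch_first=True)
print(out_lengths.tolist())
print(output.shape, hidden.shape)
```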
Prevention
Always keep track of sequence lengths. Before packing, either sort each batch by length in descending order or pass enforce_sorted=False. Use pack_padded_sequence before the RNN and pad_packed_sequence after it to handle variable-length sequences. This keeps padding out of the computation and ensures each final hidden state corresponds to the last real element of its sequence.
Use helper functions or wrappers, such as a custom collate_fn for your DataLoader, to automate padding, length tracking, and packing in your data pipeline.
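One way to automate this is a collate_fn that pads each batch and returns the lengths alongside it, so packing needs no extra bookkeeping in the training loop. This is a minimal sketch under two assumptions: the toy dataset and the function name collate_variable_length are hypothetical, and each sample is a 1-D tensor of scalar features.

pad_sequence, DataLoader and its collate_fn parameter are standard PyTorch.

```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.utils.data import DataLoader

# Hypothetical toy dataset: a list of variable-length 1-D tensors
data = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0]), torch.tensor([6.0])]

def collate_variable_length(batch):
    """Hypothetical collate_fn: pad a list of 1-D tensors and return their lengths."""
    lengths = torch.tensor([len(seq) for seq in batch])  # CPU int64, as packing expects
    padded = pad_sequence(batch, batch_first=True, padding_value=0.0)  # (batch, max_len)
    return padded.unsqueeze(-1), lengths  # add a feature dimension of size 1

loader = DataLoader(data, batch_size=3, collate_fn=collate_variable_length)
rnn = torch.nn.RNN(input_size=1, hidden_size=2, batch_first=True)

for padded, lengths in loader:
    # enforce_sorted=False lets each batch stay in dataset order
    packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
    packed_output, hidden = rnn(packed)
    print(padded.shape, lengths.tolist(), hidden.shape)
```

Keeping the lengths next to the padded tensor means a collate function like this is the single place that knows about padding. The model code only has to pack and unpack.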
Related Errors
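RuntimeError: Expected input batch_size to be equal to hidden batch_size (or a similar hidden-size mismatch) usually means that an initial hidden state you passed does not match the input's batch size. A common cause is building h0 in a batch-first layout. The hidden state is always shaped (num_layers * num_directions, batch, hidden_size), even with batch_first=True. Feeding padded sequences without packing does not trigger this error; it runs silently and gives wrong results instead.
RuntimeError: `lengths` array must be sorted in decreasing order when `enforce_sorted` is True means the lengths passed to pack_padded_sequence were not in descending order.
Fix these by sorting the batch by length, or passing enforce_sorted=False, before calling pack_padded_sequence. Also shape any initial hidden state as (num_layers * num_directions, batch, hidden_size).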
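The sketch below reproduces both problems on a toy batch and then fixes them. The exact wording of the hidden-size error can vary between PyTorch versions, so the sketch only checks that a RuntimeError is raised rather than matching its text.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

sequences = torch.tensor([[4, 5, 0, 0], [1, 2, 3, 0]]).float().unsqueeze(-1)
lengths = torch.tensor([2, 3])  # ascending, so invalid for the default enforce_sorted=True

# 1) Unsorted lengths with enforce_sorted=True (the default) raise a RuntimeError
sort_error = None
try:
    pack_padded_sequence(sequences, lengths, batch_first=True)
except RuntimeError as err:
    sort_error = err
print("sort error raised:", sort_error is not None)

# Fix: sort the batch in descending order of length before packing
lengths, perm_idx = lengths.sort(descending=True)
packed = pack_padded_sequence(sequences[perm_idx], lengths, batch_first=True)

# 2) The initial hidden state is (num_layers * num_directions, batch, hidden_size);
#    batch_first=True changes only the input/output layout, not the hidden state's
rnn = torch.nn.RNN(input_size=1, hidden_size=2, batch_first=True)
batch_size = sequences.size(0)

shape_error = None
try:
    # Wrong: hidden state laid out batch-first as (batch, num_layers, hidden_size)
    rnn(packed, torch.zeros(batch_size, rnn.num_layers, rnn.hidden_size))
except RuntimeError as err:
    shape_error = err
print("hidden-shape error raised:", shape_error is not None)

# Right: (num_layers, batch, hidden_size) for a unidirectional RNN
h0 = torch.zeros(rnn.num_layers, batch_size, rnn.hidden_size)
packed_output, hidden = rnn(packed, h0)
print(hidden.shape)
```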