How to Use nn.LSTM in PyTorch: Syntax and Example
Use nn.LSTM in PyTorch by creating an LSTM layer with input and hidden sizes, then pass input tensors through it to get the output and hidden states. Initialize the LSTM with nn.LSTM(input_size, hidden_size, num_layers) and call it with input shaped as (seq_len, batch, input_size).
Syntax
The nn.LSTM module in PyTorch is initialized with key parameters:
- input_size: Number of expected features in the input.
- hidden_size: Number of features in the hidden state.
- num_layers: Number of stacked LSTM layers (default is 1).
When you call the LSTM, input must be a tensor of shape (sequence_length, batch_size, input_size). The output is a tuple containing:
- output: All hidden states for each time step.
- (h_n, c_n): The hidden and cell states for the last time step.
```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
input = torch.randn(5, 3, 10)  # seq_len=5, batch=3, input_size=10
output, (h_n, c_n) = lstm(input)
```
Example
This example creates a 2-layer LSTM with input size 4 and hidden size 3. It passes a random input tensor of shape (6, 1, 4) representing 6 time steps, batch size 1, and 4 features. The code prints the output shape and the last hidden state shape.
```python
import torch
import torch.nn as nn

# Define LSTM
lstm = nn.LSTM(input_size=4, hidden_size=3, num_layers=2)

# Random input: seq_len=6, batch=1, input_size=4
input = torch.randn(6, 1, 4)

# Forward pass
output, (h_n, c_n) = lstm(input)
print('Output shape:', output.shape)       # (seq_len, batch, hidden_size)
print('Hidden state shape:', h_n.shape)    # (num_layers, batch, hidden_size)
```
Output
Output shape: torch.Size([6, 1, 3])
Hidden state shape: torch.Size([2, 1, 3])
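As a quick sanity check on these shapes: for a unidirectional LSTM, the last time step of output is the same tensor slice as the final hidden state of the top layer. A minimal sketch:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=3, num_layers=2)
input = torch.randn(6, 1, 4)
output, (h_n, c_n) = lstm(input)

# output[-1] is the top layer's hidden state at the final time step,
# which matches h_n[-1] for a unidirectional LSTM.
print(torch.allclose(output[-1], h_n[-1]))  # True
```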
Common Pitfalls
- Input shape mismatch: The input must be 3D with shape (seq_len, batch, input_size). Passing (batch, seq_len, input_size) is silently misinterpreted (the first dimension is read as the sequence), producing wrong results, or a shape error if the last dimension no longer matches input_size.
- Hidden state initialization: If not provided, hidden and cell states default to zeros. For stateful LSTMs, you must pass them explicitly.
- Batch size consistency: Batch size must be consistent across input and hidden states.
```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=3)

# Batch-first tensor passed to a seq-first LSTM: no error is raised,
# because the last dimension still matches input_size=4, but PyTorch
# reads it as seq_len=1, batch=6 instead of batch=1, seq_len=6.
wrong_input = torch.randn(1, 6, 4)  # intended: batch=1, seq_len=6, input_size=4
output, _ = lstm(wrong_input)
print('Output shape with wrong input:', output.shape)  # (1, 6, 3) -- misread

# Correct input shape: transpose to (seq_len, batch, input_size)
correct_input = wrong_input.transpose(0, 1)  # (6, 1, 4)
output, (h_n, c_n) = lstm(correct_input)
print('Output shape with correct input:', output.shape)  # (6, 1, 3)
```
Output
Output shape with wrong input: torch.Size([1, 6, 3])
Output shape with correct input: torch.Size([6, 1, 3])
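The hidden state pitfall can be sketched as well. Below is a minimal example (not from the original) of passing (h_0, c_0) explicitly and carrying the returned state into the next call, which is what "stateful" processing requires:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=3, num_layers=2)

# Explicit initial states: shape (num_layers, batch, hidden_size)
h_0 = torch.zeros(2, 1, 3)
c_0 = torch.zeros(2, 1, 3)

chunk1 = torch.randn(6, 1, 4)
chunk2 = torch.randn(6, 1, 4)

# Carry the state from one chunk into the next
out1, (h_n, c_n) = lstm(chunk1, (h_0, c_0))
out2, (h_n, c_n) = lstm(chunk2, (h_n, c_n))
print(out2.shape)  # torch.Size([6, 1, 3])
```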
Quick Reference
Remember these tips when using nn.LSTM:
- Input shape: (seq_len, batch, input_size)
- Output shape: (seq_len, batch, hidden_size)
- Hidden states shape: (num_layers, batch, hidden_size)
- Default hidden and cell states are zeros if not provided
- Use batch_first=True if you prefer input shape (batch, seq_len, input_size)
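The batch_first option from the last tip can be sketched as follows. Note that batch_first only changes the layout of input and output, not of h_n and c_n:

```python
import torch
import torch.nn as nn

# With batch_first=True the LSTM expects (batch, seq_len, input_size)
lstm = nn.LSTM(input_size=4, hidden_size=3, batch_first=True)
input = torch.randn(1, 6, 4)  # batch=1, seq_len=6, input_size=4
output, (h_n, c_n) = lstm(input)
print(output.shape)  # torch.Size([1, 6, 3]) -- (batch, seq_len, hidden_size)
print(h_n.shape)     # torch.Size([1, 1, 3]) -- still (num_layers, batch, hidden_size)
```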
Key Takeaways
- Initialize nn.LSTM with input_size, hidden_size, and num_layers to define the model.
- Input tensors must have shape (sequence_length, batch_size, input_size) unless batch_first=True is set.
- The LSTM output includes all hidden states and the final hidden and cell states.
- Common bugs come from wrong input shapes or mismatched batch sizes.
- Use batch_first=True to work with input shape (batch, seq_len, input_size) if preferred.