How to Build an RNN in PyTorch: Simple Guide with Example
To build an RNN in PyTorch, use the torch.nn.RNN module inside a custom nn.Module class. Define the RNN layer, pass input sequences through it, and use the output for predictions or further layers.
Syntax
The basic way to create an RNN layer in PyTorch is to instantiate torch.nn.RNN, specifying the input size, hidden size, number of layers, and other options.
- input_size: Number of features in the input.
- hidden_size: Number of features in the hidden state.
- num_layers: Number of stacked RNN layers.
- batch_first: If True, input shape is (batch, seq, feature). The default is False, which expects (seq, batch, feature).
You then pass input tensors to the RNN layer to get output and hidden states.
python
import torch

rnn = torch.nn.RNN(input_size=10, hidden_size=20, num_layers=1, batch_first=True)
x = torch.randn(5, 3, 10)  # batch=5, seq_len=3, features=10
output, hidden = rnn(x)
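To make the two return values concrete, here is a minimal sketch that prints their shapes. It uses num_layers=2 (an illustrative choice, not taken from the snippet above) so that the layer dimension of the hidden state is visible.

```python
import torch

# Two stacked layers so the layer dimension of the hidden state is visible.
rnn = torch.nn.RNN(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
x = torch.randn(5, 3, 10)  # (batch=5, seq_len=3, features=10)

output, hidden = rnn(x)

# output: the top layer's hidden state at every time step, batch-first.
print(output.shape)  # torch.Size([5, 3, 20])
# hidden: the final hidden state of every layer; batch_first does not change its layout.
print(hidden.shape)  # torch.Size([2, 5, 20])
```

Note that the hidden state stays in (num_layers, batch, hidden_size) order even when batch_first=True.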
Example
This example shows a simple RNN model class in PyTorch that takes input sequences and outputs predictions. It demonstrates defining the RNN layer, forward pass, and running a dummy input through the model.
python
import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, hidden = self.rnn(x)  # out shape: (batch, seq_len, hidden_size)
        out = out[:, -1, :]        # take last time step output
        out = self.fc(out)         # map to output size
        return out

# Parameters
input_size = 4
hidden_size = 8
output_size = 1
batch_size = 2
seq_len = 5

# Create model
model = SimpleRNN(input_size, hidden_size, output_size)

# Dummy input: batch of 2 sequences, each with 5 time steps and 4 features
input_tensor = torch.randn(batch_size, seq_len, input_size)

# Forward pass
output = model(input_tensor)
print(output)
Output
tensor([[ 0.0185],
        [-0.0517]], grad_fn=<AddmmBackward0>)
The exact values differ on every run because the weights and the input are randomly initialized.
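The model above slices the last time step out of the output tensor. As a rough check of what that slice contains, the sketch below compares it with the final hidden state returned by the layer. It assumes a unidirectional RNN and sequences without padding; the seed value is arbitrary and only makes the run repeatable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for repeatability

rnn = nn.RNN(input_size=4, hidden_size=8, batch_first=True)
x = torch.randn(2, 5, 4)  # (batch, seq_len, features)

out, hidden = rnn(x)
last_step = out[:, -1, :]  # output at the last time step, (batch, hidden_size)
final_hidden = hidden[-1]  # final state of the last layer, (batch, hidden_size)

# For a unidirectional RNN these two tensors hold the same values.
same = torch.allclose(last_step, final_hidden)
print(same)
```

With bidirectional=True the two are no longer interchangeable, because the reverse direction finishes at the first time step rather than the last.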
Common Pitfalls
Common mistakes when building RNNs in PyTorch include:
- Not setting batch_first=True when the input shape is (batch, seq, feature), so the layer swaps the batch and sequence axes, often without raising an error.
- Forgetting to take the last time step output if you want a single prediction per sequence.
- Not initializing hidden states when needed (though PyTorch defaults to zeros).
- Mixing up input dimensions or forgetting to match input_size with feature size.
python
import torch
import torch.nn as nn

# Wrong: features are not in the last dimension, so its size does not match input_size.
# (A 2-D tensor such as torch.randn(10, 3) is not a reliable error case: recent
# PyTorch releases accept it as a single unbatched sequence.)
rnn_wrong = nn.RNN(input_size=3, hidden_size=5, batch_first=True)
input_wrong = torch.randn(2, 3, 4)  # (batch=2, features=3, seq_len=4): axes in the wrong order
try:
    output, hidden = rnn_wrong(input_wrong)
except Exception as e:
    print(f"Error: {e}")

# Right: (batch, seq_len, features) with features == input_size
rnn_right = nn.RNN(input_size=3, hidden_size=5, batch_first=True)
input_right = torch.randn(2, 4, 3)  # batch=2, seq_len=4, features=3
output, hidden = rnn_right(input_right)
print("Output shape:", output.shape)
Output
Error: input.size(-1) must be equal to input_size. Expected 3, got 4
Output shape: torch.Size([2, 4, 5])
The wording of the error message may differ between PyTorch versions.
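The hidden-state pitfall from the list above can be illustrated with a short sketch. The layer sizes and the all-ones starting state are illustrative assumptions. It passes an explicit initial state h0 as the second argument and compares the result with the default behavior.

```python
import torch
import torch.nn as nn

num_layers, batch_size, hidden_size = 1, 2, 5
rnn = nn.RNN(input_size=3, hidden_size=hidden_size,
             num_layers=num_layers, batch_first=True)
x = torch.randn(batch_size, 4, 3)  # (batch, seq_len, features)

# h0 is always (num_layers, batch, hidden_size), even with batch_first=True.
h0_zeros = torch.zeros(num_layers, batch_size, hidden_size)
out_default, _ = rnn(x)          # no h0 given: PyTorch starts from zeros
out_zeros, _ = rnn(x, h0_zeros)  # explicit zeros: same result as the default
matches_default = torch.allclose(out_default, out_zeros)

# A non-zero starting state (all ones here, purely for illustration) changes the output.
h0_custom = torch.ones(num_layers, batch_size, hidden_size)
out_custom, _ = rnn(x, h0_custom)
differs = not torch.allclose(out_default, out_custom)

print(matches_default, differs)
```

Passing h0 is only needed when the sequence should continue from a known state, for example when a long sequence is processed in chunks.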
Quick Reference
Remember these tips when building RNNs in PyTorch:
- Use nn.RNN with correct input_size and hidden_size.
- Set batch_first=True if your input shape is (batch, seq, feature).
- Extract the last time step output for sequence-level predictions.
- Initialize hidden states if you want custom values; otherwise, PyTorch uses zeros.
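To see what batch_first actually changes, the following sketch builds two layers with identical weights, one batch-first and one with the default sequence-first layout, and feeds each the same data in its expected layout. The variable names and sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for repeatability

rnn_bf = nn.RNN(input_size=3, hidden_size=5, batch_first=True)
rnn_sf = nn.RNN(input_size=3, hidden_size=5)  # batch_first defaults to False
rnn_sf.load_state_dict(rnn_bf.state_dict())   # give both layers identical weights

x = torch.randn(2, 4, 3)  # (batch=2, seq_len=4, features=3)

out_bf, _ = rnn_bf(x)                               # output (2, 4, 5)
out_sf, _ = rnn_sf(x.transpose(0, 1).contiguous())  # input (4, 2, 3), output (4, 2, 5)

# Same data, same weights: only the axis order of the tensors differs.
equivalent = torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-6)

# Feeding the batch-first tensor to the sequence-first layer raises no error,
# but the layer then reads dimension 0 as time and dimension 1 as batch.
out_misread, _ = rnn_sf(x)

print(equivalent, out_misread.shape)
```

Because the misread call still returns a tensor of a plausible shape, a wrong batch_first setting can go unnoticed until the model fails to learn.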
Key Takeaways
Use torch.nn.RNN inside a custom nn.Module to build RNNs in PyTorch.
Set batch_first=True if your input shape is (batch, sequence, features).
Extract the last time step output for a single prediction per sequence.
Match input_size with your input feature dimension to avoid shape errors.
Initialize hidden states only if you need custom starting states; default is zeros.