Hidden state management keeps track of information across time steps in sequence models like RNNs. By carrying a hidden state forward, the model remembers past inputs and makes better predictions.
Hidden state management in PyTorch
Introduction
Hidden states matter whenever a model must remember earlier inputs, for example:
When processing sentences word by word to understand context.
When analyzing time series data like stock prices over days.
When generating text step by step, like in chatbots.
When recognizing speech sounds in a sequence.
When predicting the next move in a game based on past moves.
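To make the step-by-step idea concrete, here is a minimal sketch (the layer sizes are made up for illustration) that feeds a sequence one time step at a time and carries the hidden state forward between steps:

import torch
import torch.nn as nn

# Hypothetical setup: each "word" is a 4-dimensional vector.
rnn = nn.RNN(input_size=4, hidden_size=8, batch_first=True)

# One sequence of 3 steps (batch size 1).
sequence = torch.randn(1, 3, 4)

hidden = None  # let PyTorch initialize the hidden state to zeros
for t in range(sequence.size(1)):
    step = sequence[:, t:t + 1, :]        # one time step, shape (1, 1, 4)
    output, hidden = rnn(step, hidden)    # hidden carries context forward

print(hidden.shape)  # (num_layers, batch_size, hidden_size) = (1, 1, 8)

Because the hidden state is passed back in at each step, the final hidden state summarizes the whole sequence.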
Syntax
PyTorch
hidden = torch.zeros(num_layers, batch_size, hidden_size)
hidden = model.init_hidden(batch_size)
output, hidden = model(input, hidden)
The hidden state is usually a tensor that holds past information.
You pass the hidden state to the model and get an updated hidden state back.
Examples
Initialize hidden state with zeros for 1 layer, batch size 1, and hidden size 10.
PyTorch
hidden = torch.zeros(1, 1, 10)
output, hidden = rnn(input, hidden)
Use a model method to create hidden state for batch size 5, then run input through model.
PyTorch
hidden = model.init_hidden(batch_size=5)
output, hidden = model(input, hidden)
Passing None lets PyTorch initialize hidden state automatically.
PyTorch
hidden = None
output, hidden = rnn(input, hidden)
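The snippets above assume an existing rnn and input. A self-contained, runnable sketch of the hidden = None pattern (with illustrative sizes) looks like this:

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=3, hidden_size=10, batch_first=True)
inputs = torch.randn(1, 5, 3)  # batch of 1, sequence length 5, input size 3

# Passing None makes PyTorch create a zero hidden state internally.
output, hidden = rnn(inputs, None)

print(output.shape)  # (1, 5, 10): one output per time step
print(hidden.shape)  # (1, 1, 10): final hidden state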
Sample Model
This code creates a simple RNN model that takes sequences whose elements are vectors of 3 numbers, processes them, and outputs 2 numbers per sequence. It shows how to initialize and pass the hidden state.
PyTorch
import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out[:, -1, :])  # use the last time step's output
        return out, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(1, batch_size, self.hidden_size)

# Parameters
input_size = 3
hidden_size = 5
output_size = 2
batch_size = 4
seq_len = 6

# Create model
model = SimpleRNN(input_size, hidden_size, output_size)

# Random input: batch_size sequences, each of length seq_len, each element of size input_size
inputs = torch.randn(batch_size, seq_len, input_size)

# Initialize hidden state
hidden = model.init_hidden(batch_size)

# Forward pass
outputs, hidden = model(inputs, hidden)

print("Output shape:", outputs.shape)
print("Output values:", outputs)
print("Hidden state shape:", hidden.shape)
Output
Important Notes
Always match the hidden state shape to (num_layers, batch_size, hidden_size).
Keep hidden state between batches if you want to remember across sequences.
Reset hidden state to zeros when starting fresh sequences to avoid mixing data.
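When keeping hidden state between batches during training, one common pattern is to detach it so gradients do not flow back through earlier batches. This is a sketch with made-up sizes, not part of the model above:

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=3, hidden_size=5, batch_first=True)
hidden = torch.zeros(1, 2, 5)  # (num_layers, batch_size, hidden_size)

for batch_idx in range(3):
    inputs = torch.randn(2, 4, 3)         # each batch: 2 sequences of length 4
    output, hidden = rnn(inputs, hidden)  # hidden carries over between batches
    hidden = hidden.detach()              # cut the gradient history here

print(hidden.requires_grad)  # False after detach

Without the detach, backpropagation would try to reach through every previous batch, which grows memory use and usually raises an error once an earlier graph has been freed.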
Summary
Hidden state stores past information in sequence models.
Initialize hidden state before feeding data to the model.
Pass hidden state along with input to keep track of sequence context.