
Hidden state management in PyTorch

Introduction

Hidden state management keeps track of information over time in sequence models like RNNs. The hidden state carries context from earlier inputs forward, so the model can use past data to make better predictions.

When processing sentences word by word to understand context.
When analyzing time series data like stock prices over days.
When generating text step by step, like in chatbots.
When recognizing speech sounds in a sequence.
When predicting the next move in a game based on past moves.
Syntax
PyTorch
hidden = torch.zeros(num_layers, batch_size, hidden_size)
hidden = model.init_hidden(batch_size)
output, hidden = model(input, hidden)

The hidden state is a tensor that summarizes what the model has seen so far in the sequence.

You pass the hidden state to the model along with the input and get an updated hidden state back.
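This pass-and-update cycle can be sketched by feeding a sequence one time step at a time and carrying the hidden state forward. The layer sizes here are illustrative, not part of the syntax above.

```python
import torch
import torch.nn as nn

# Illustrative sizes: input_size=3, hidden_size=10, one layer
rnn = nn.RNN(input_size=3, hidden_size=10, batch_first=True)

hidden = torch.zeros(1, 1, 10)       # (num_layers, batch_size, hidden_size)
sequence = torch.randn(1, 7, 3)      # one batch, 7 time steps, 3 features each

for t in range(sequence.size(1)):
    step = sequence[:, t:t+1, :]     # a single time step
    output, hidden = rnn(step, hidden)  # updated hidden carries context forward

print(hidden.shape)  # torch.Size([1, 1, 10])
```

Feeding the whole sequence at once gives the same final hidden state; the step-by-step loop just makes the hand-off explicit.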

Examples
Initialize hidden state with zeros for 1 layer, batch size 1, and hidden size 10.
PyTorch
hidden = torch.zeros(1, 1, 10)
output, hidden = rnn(input, hidden)
Use a model method to create hidden state for batch size 5, then run input through model.
PyTorch
hidden = model.init_hidden(batch_size=5)
output, hidden = model(input, hidden)
Passing None lets PyTorch initialize the hidden state automatically (to zeros).
PyTorch
hidden = None
output, hidden = rnn(input, hidden)
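As a quick check, passing None produces the same result as passing an explicit all-zeros hidden state. The layer sizes below are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=3, hidden_size=5, batch_first=True)
x = torch.randn(2, 4, 3)  # batch of 2, sequence length 4

# None defaults to a zeros hidden state of shape (num_layers, batch, hidden)
out_none, h_none = rnn(x, None)
out_zero, h_zero = rnn(x, torch.zeros(1, 2, 5))

print(torch.allclose(out_none, out_zero))  # True
```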
Sample Model

This code defines a simple RNN model that takes sequences whose elements are vectors of size 3, processes them, and outputs 2 values per sequence. It shows how to initialize the hidden state and pass it through the forward call.

PyTorch
import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out[:, -1, :])
        return out, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(1, batch_size, self.hidden_size)

# Parameters
input_size = 3
hidden_size = 5
output_size = 2
batch_size = 4
seq_len = 6

# Create model
model = SimpleRNN(input_size, hidden_size, output_size)

# Random input: batch_size sequences, each of length seq_len, each element size input_size
inputs = torch.randn(batch_size, seq_len, input_size)

# Initialize hidden state
hidden = model.init_hidden(batch_size)

# Forward pass
outputs, hidden = model(inputs, hidden)

print("Output shape:", outputs.shape)
print("Output values:", outputs)
print("Hidden state shape:", hidden.shape)
Important Notes

Always match the hidden state shape to (num_layers, batch_size, hidden_size); for bidirectional RNNs the first dimension becomes num_layers * 2.

Keep the hidden state between batches if you want to remember across sequences; during training, detach it first so gradients do not flow back through earlier batches.

Reset hidden state to zeros when starting fresh sequences to avoid mixing data.
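The notes above can be sketched together: carry and detach the hidden state across batches, then reset it when a fresh sequence begins. The layer and batch sizes here are illustrative.

```python
import torch
import torch.nn as nn

rnn = nn.RNN(input_size=3, hidden_size=5, batch_first=True)
hidden = None  # fresh start: PyTorch fills in zeros

for batch in [torch.randn(2, 4, 3) for _ in range(3)]:
    output, hidden = rnn(batch, hidden)
    # Detach so gradients don't backpropagate through earlier batches
    hidden = hidden.detach()

# Starting a fresh, unrelated sequence: reset to avoid mixing data
hidden = None  # or torch.zeros(1, batch_size, 5)
```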

Summary

Hidden state stores past information in sequence models.

Initialize hidden state before feeding data to the model.

Pass hidden state along with input to keep track of sequence context.