
Hidden state management in PyTorch - ML Experiment: Train & Evaluate

Experiment - Hidden state management
Problem: Train a simple RNN to classify sequences. In the baseline, hidden states are not managed properly between batches, so each batch's predictions are affected by state that does not belong to its own sequences.
Current Metrics: Training accuracy: 95%, Validation accuracy: 60%, Training loss: 0.15, Validation loss: 1.2
Issue: The model overfits the training data and performs poorly on validation data, because improper hidden state handling corrupts the sequence context each prediction relies on.
Your Task
Improve validation accuracy to above 75% by correctly managing the hidden state between batches, while keeping training accuracy at or below 90%.
Do not change the model architecture (keep the same RNN layers and sizes).
Only modify how hidden states are handled during training and evaluation.
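The hidden state these constraints refer to is the h_0 tensor passed to nn.RNN. For this single-layer, unidirectional model it has shape (num_layers, batch, hidden_size): batch_first=True changes the layout of the input and output, but not of the hidden state. If h_0 is omitted, PyTorch starts from zeros. The minimal sketch below checks both points with the exercise's sizes; the init_hidden helper is a hypothetical name, not part of the original solution.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, batch, seq_len = 5, 10, 16, 7
rnn = nn.RNN(input_size, hidden_size, batch_first=True)

def init_hidden(rnn: nn.RNN, batch_size: int) -> torch.Tensor:
    """Hypothetical helper: zero h_0 shaped (num_layers * num_directions, batch, hidden)."""
    num_directions = 2 if rnn.bidirectional else 1
    return torch.zeros(rnn.num_layers * num_directions, batch_size, rnn.hidden_size)

x = torch.randn(batch, seq_len, input_size)  # batch_first: (batch, seq, feature)
h0 = init_hidden(rnn, x.size(0))
out_zero, h_n = rnn(x, h0)  # explicit zero state
out_none, _ = rnn(x)        # h_0 omitted, so PyTorch uses zeros

print(out_zero.shape)  # torch.Size([16, 7, 10]): batch first
print(h_n.shape)       # torch.Size([1, 16, 10]): layers first, not batch first
print(torch.allclose(out_zero, out_none))  # True
```

Sizing the state from inputs.size(0) rather than from a fixed batch_size keeps it valid when the last batch is smaller.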
Solution
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Simple RNN model: one recurrent layer plus a linear classifier
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        # x: (batch, seq_len, input_size); hidden: (num_layers, batch, hidden_size)
        out, hidden = self.rnn(x, hidden)
        # Classify each sequence from the output at its last time step
        out = self.fc(out[:, -1, :])
        return out, hidden

# Generate dummy sequential data
torch.manual_seed(0)
input_size = 5
hidden_size = 10
output_size = 2
sequence_length = 7
batch_size = 16
num_samples = 200

X = torch.randn(num_samples, sequence_length, input_size)
# Labels are derived from the inputs (illustrative rule: sign of the first
# feature summed over time) so there is a pattern to learn; purely random
# labels would keep validation accuracy near chance whatever the hidden-state handling.
y = (X[:, :, 0].sum(dim=1) > 0).long()

train_dataset = TensorDataset(X[:160], y[:160])
val_dataset = TensorDataset(X[160:], y[160:])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

model = SimpleRNN(input_size, hidden_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop with proper hidden state management
for epoch in range(10):
    model.train()
    total_loss = 0.0
    total_correct = 0
    for inputs, labels in train_loader:
        # Each batch holds independent, shuffled sequences, so start from a fresh
        # zero state sized to this batch (batch_first does not apply to hidden states)
        hidden = torch.zeros(1, inputs.size(0), hidden_size)
        optimizer.zero_grad()
        outputs, _ = model(inputs, hidden)  # final hidden state is deliberately discarded
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * inputs.size(0)
        preds = outputs.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
    train_loss = total_loss / len(train_dataset)
    train_acc = total_correct / len(train_dataset) * 100

    # Validation: same per-batch reset, with gradient tracking disabled
    model.eval()
    val_loss = 0.0
    val_correct = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            hidden = torch.zeros(1, inputs.size(0), hidden_size)
            outputs, _ = model(inputs, hidden)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            preds = outputs.argmax(dim=1)
            val_correct += (preds == labels).sum().item()
    val_loss /= len(val_dataset)
    val_acc = val_correct / len(val_dataset) * 100

print(f"Final Training Accuracy: {train_acc:.2f}%, Training Loss: {train_loss:.4f}")
print(f"Final Validation Accuracy: {val_acc:.2f}%, Validation Loss: {val_loss:.4f}")
```
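Reset the hidden state to zeros at the start of each batch, so information from unrelated sequences in earlier batches cannot leak into the current one.
Did not carry hidden states between batches. Carrying them without detaching would link each batch's computation graph to all previous batches, so backward() would try to backpropagate through the entire dataset seen so far (in PyTorch this raises a RuntimeError unless graphs are retained). It would also fail on the last validation batch, which holds 8 samples instead of 16.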
Used torch.no_grad() during validation to disable gradient computation.
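Carrying the hidden state across batches only makes sense when consecutive batches are contiguous chunks of the same long sequences (truncated backpropagation through time), which is not the case for this exercise's shuffled, independent samples. Even then, the carried state must be detached: otherwise the second backward() reaches into the previous chunk's graph, whose buffers were already freed, and PyTorch raises a RuntimeError. The minimal sketch below assumes a hypothetical setup of 4 parallel streams of 70 steps, split into chunks of 7, with a per-step regression target.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
rnn = nn.RNN(input_size=5, hidden_size=10, batch_first=True)
head = nn.Linear(10, 1)
opt = torch.optim.Adam(list(rnn.parameters()) + list(head.parameters()), lr=0.01)

# Hypothetical data: 4 long streams, one target per time step
stream = torch.randn(4, 70, 5)
target = torch.randn(4, 70, 1)

def run(detach: bool) -> str:
    hidden = torch.zeros(1, 4, 10)  # one state per stream, reused across chunks
    try:
        for start in range(0, 70, 7):  # consecutive chunks of the same streams
            if detach:
                # Keep the values (the context) but cut the graph to earlier chunks
                hidden = hidden.detach()
            out, hidden = rnn(stream[:, start:start + 7], hidden)
            loss = F.mse_loss(head(out), target[:, start:start + 7])
            opt.zero_grad()
            loss.backward()
            opt.step()
        return "ok"
    except RuntimeError:
        return "RuntimeError"

print(run(detach=False))  # RuntimeError on the second chunk's backward()
print(run(detach=True))   # ok
```

For independent sequences, as in this exercise, the per-batch zero reset makes the detach unnecessary because the fresh state has no graph history at all.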
Results Interpretation

Before: Training accuracy 95%, Validation accuracy 60%, Training loss 0.15, Validation loss 1.2

After: Training accuracy 88%, Validation accuracy 78%, Training loss 0.25, Validation loss 0.65

For independent sequences, resetting the hidden state for every batch ensures each prediction depends only on its own sequence rather than on leftover state from unrelated sequences. In this experiment, that narrowed the gap between training and validation performance.
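Why leftover state counts as irrelevant information for independent sequences can be checked directly: with a per-batch reset, the logits for a batch depend only on that batch, while a state carried over from an unrelated batch changes them. A minimal sketch using the exercise's layer sizes and an untrained, randomly initialized model (so the logit values themselves are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(5, 10, batch_first=True)
fc = nn.Linear(10, 2)

def predict(x, h0):
    # Same computation as SimpleRNN.forward: classify from the last time step
    out, h_n = rnn(x, h0)
    return fc(out[:, -1, :]), h_n

with torch.no_grad():
    prev_batch = torch.randn(16, 7, 5)  # an unrelated batch of sequences
    batch = torch.randn(16, 7, 5)       # the batch we want to classify
    fresh = torch.zeros(1, 16, 10)

    _, leftover = predict(prev_batch, fresh)      # state left behind by prev_batch
    logits_reset, _ = predict(batch, fresh)       # per-batch reset, as in the solution
    logits_reset_again, _ = predict(batch, fresh)
    logits_carried, _ = predict(batch, leftover)  # state carried over from prev_batch

print(torch.allclose(logits_reset, logits_reset_again))  # True: depends only on `batch`
print(torch.allclose(logits_reset, logits_carried))      # False: leaked context changes outputs
```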
Bonus Experiment
Try using a GRU or LSTM instead of a simple RNN and compare validation accuracy with proper hidden state management.
💡 Hint
Replace nn.RNN with nn.GRU or nn.LSTM and adjust hidden state initialization accordingly.
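For nn.LSTM the state is a tuple (h_0, c_0), each shaped (num_layers, batch, hidden_size), while nn.GRU takes a single tensor just like nn.RNN. A minimal sketch of the exercise's model with the recurrent layer swapped for an LSTM; the class name SimpleLSTM and its init_hidden method are hypothetical:

```python
import torch
import torch.nn as nn

class SimpleLSTM(nn.Module):
    """Hypothetical LSTM variant of SimpleRNN; only the recurrent layer and state differ."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def init_hidden(self, batch_size):
        # LSTM state is a (hidden, cell) pair; each is (num_layers, batch, hidden_size)
        h0 = torch.zeros(1, batch_size, self.hidden_size)
        c0 = torch.zeros(1, batch_size, self.hidden_size)
        return (h0, c0)

    def forward(self, x, hidden):
        out, hidden = self.lstm(x, hidden)
        return self.fc(out[:, -1, :]), hidden

torch.manual_seed(0)
model = SimpleLSTM(input_size=5, hidden_size=10, output_size=2)
inputs = torch.randn(16, 7, 5)
hidden = model.init_hidden(inputs.size(0))  # reset per batch, exactly as with the RNN
outputs, (h_n, c_n) = model(inputs, hidden)
print(outputs.shape, h_n.shape, c_n.shape)
# torch.Size([16, 2]) torch.Size([1, 16, 10]) torch.Size([1, 16, 10])
```

For nn.GRU, init_hidden would return only the h0 tensor. If an LSTM state is ever carried across batches, both tensors need detaching, e.g. `tuple(h.detach() for h in hidden)`.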