Bidirectional RNNs are used to understand sequences from both past and future context. The key metrics to check are accuracy for classification tasks and loss for sequence prediction. For tasks like speech recognition or text tagging, precision, recall, and F1 score are important to measure how well the model predicts each class, especially when classes are imbalanced.
Bidirectional RNNs in PyTorch - Model Metrics & Evaluation
Suppose a bidirectional RNN classifies words into two classes: Positive (P) and Negative (N). Here is a confusion matrix:
|        | Predicted P | Predicted N |
|--------|-------------|-------------|
| True P | 50          | 10          |
| True N | 5           | 35          |
Total samples = 50 + 10 + 5 + 35 = 100
From this matrix:
- Precision = TP / (TP + FP) = 50 / (50 + 5) ≈ 0.91
- Recall = TP / (TP + FN) = 50 / (50 + 10) ≈ 0.83
- F1 Score = 2 * (Precision * Recall) / (Precision + Recall) ≈ 0.87
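The arithmetic above can be reproduced in a few lines of plain Python. This is a minimal sketch that takes the four counts straight from the confusion matrix; the variable names are illustrative and not part of any library.

```python
# Counts read from the confusion matrix above (positive class = P).
tp, fn = 50, 10   # true P row: predicted P, predicted N
fp, tn = 5, 35    # true N row: predicted P, predicted N

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
accuracy = (tp + tn) / (tp + fn + fp + tn)

print(f"precision={precision:.2f} recall={recall:.2f} "
      f"f1={f1:.2f} accuracy={accuracy:.2f}")
# precision=0.91 recall=0.83 f1=0.87 accuracy=0.85
```

Accuracy is included only for comparison; it comes out at 0.85 for the same matrix.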
In bidirectional RNNs, depending on the task, you might want to balance precision and recall differently.
- High Precision: Useful when false positives are costly. For example, in medical diagnosis, wrongly predicting a disease when it is not present can cause unnecessary stress and treatment.
- High Recall: Important when missing a positive case is dangerous. For example, in fraud detection, missing a fraud case (false negative) is worse than flagging a normal case.
Bidirectional RNNs help by using context from both directions, which can improve both precision and recall compared to unidirectional models.
For a bidirectional RNN on a balanced classification task, rough rules of thumb (the right thresholds depend on the task) are:
- Good: Accuracy above 85%, Precision and Recall above 80%, F1 score above 0.8.
- Bad: Accuracy below 60%, Precision or Recall below 50%, F1 score below 0.5.
Low precision means many false alarms. Low recall means many misses. Both reduce usefulness.
- Accuracy Paradox: High accuracy can be misleading if classes are imbalanced. For example, if 90% of data is class A, predicting all A gives 90% accuracy but zero recall for class B.
- Data Leakage: If future information leaks into training, metrics look better but model fails in real use.
- Overfitting: Very low training loss but high validation loss means model memorizes training data and won't generalize.
- Ignoring Sequence Length: Metrics averaged over sequences of different lengths can hide poor performance on longer sequences.
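The accuracy paradox from the list above is easy to reproduce by hand. The sketch below assumes a hypothetical data set of 100 samples with the 90/10 split mentioned there and a degenerate model that always predicts class A; no real model or data is involved.

```python
# Hypothetical labels: 90 samples of class A (0) and 10 of class B (1).
y_true = [0] * 90 + [1] * 10
# A degenerate "model" that always predicts class A.
y_pred = [0] * 100

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# Recall for class B: how many true B samples were predicted as B.
tp_b = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
fn_b = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
recall_b = tp_b / (tp_b + fn_b)

print(accuracy, recall_b)  # 0.9 0.0
```

This is why per-class recall should be reported next to accuracy whenever the classes are imbalanced.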
Your bidirectional RNN model has 98% accuracy but only 12% recall on the positive class (e.g., fraud). Is it good for production?
Answer: No, it is not good. The model misses 88% of positive cases, which is dangerous for fraud detection. High accuracy is misleading because most data is negative. You need to improve recall to catch more fraud cases.
Practice
bidirectional RNN compared to a standard RNN?
Solution
Step 1: Understand standard RNN processing
Standard RNNs process sequences only in the forward direction, so they only see past context.
Step 2: Analyze bidirectional RNN behavior
Bidirectional RNNs process sequences both forward and backward, capturing past and future context.
Final Answer:
It processes the input sequence in both forward and backward directions to capture full context. -> Option A
Quick Check:
Bidirectional = forward + backward context [OK]
- Thinking bidirectional reduces parameters
- Assuming it only reads backward
- Confusing with convolutional layers
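The "past and future context" claim can be checked empirically. The sketch below uses arbitrary illustrative sizes and random weights. It perturbs only the last time step of an input and compares the two halves of the output at all earlier steps: the forward half should be unaffected, while the backward half should change.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden_size = 3  # illustrative size
rnn = nn.RNN(input_size=5, hidden_size=hidden_size,
             bidirectional=True, batch_first=True)

x = torch.randn(1, 7, 5)      # batch=1, seq_len=7, input_size=5
x_changed = x.clone()
x_changed[0, -1] += 1.0       # perturb only the LAST time step

with torch.no_grad():
    out, _ = rnn(x)
    out_changed, _ = rnn(x_changed)

# Each output step is laid out as [forward half | backward half].
# Compare all steps except the perturbed one.
fwd_same = torch.allclose(out[0, :-1, :hidden_size],
                          out_changed[0, :-1, :hidden_size])
bwd_same = torch.allclose(out[0, :-1, hidden_size:],
                          out_changed[0, :-1, hidden_size:])
print(fwd_same, bwd_same)
```

With this setup the forward features at earlier steps stay identical, while the backward features at those steps change. That is the future context a unidirectional RNN cannot see.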
Solution
Step 1: Recall PyTorch GRU parameters
The bidirectional parameter is a boolean that enables bidirectional processing.
Step 2: Identify correct syntax
Only torch.nn.GRU(input_size=10, hidden_size=20, bidirectional=True) uses bidirectional=True, which is the correct PyTorch syntax.
Final Answer:
torch.nn.GRU(input_size=10, hidden_size=20, bidirectional=True) -> Option B
Quick Check:
bidirectional=True enables two directions [OK]
- Using invalid parameter names like 'direction' or 'two_directions'
- Setting bidirectional=False by mistake
- Confusing input_size and hidden_size
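As a quick sanity check of the answer above, the sketch below constructs the GRU from the solution and passes a random input through it. The input shape (seq_len=7, batch=4) is an assumption chosen for illustration.

```python
import torch
from torch import nn

gru = nn.GRU(input_size=10, hidden_size=20, bidirectional=True)

# Hypothetical input in the default layout: (seq_len, batch, input_size).
x = torch.randn(7, 4, 10)
output, h_n = gru(x)

print(gru.bidirectional)  # True
print(output.shape)       # torch.Size([7, 4, 40])  -> hidden_size * 2
print(h_n.shape)          # torch.Size([2, 4, 20])  -> one state per direction
```

The output feature dimension doubles, while the final hidden state keeps hidden_size features and gains one entry per direction along its first dimension.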
import torch
rnn = torch.nn.RNN(input_size=5, hidden_size=3, bidirectional=True, batch_first=True)
input = torch.randn(4, 7, 5)  # batch=4, seq_len=7, input_size=5
output, _ = rnn(input)
Solution
Step 1: Understand output shape of bidirectional RNN
With batch_first=True, the output shape is (batch_size, seq_len, hidden_size * num_directions). Here, num_directions=2.
Step 2: Calculate output shape
hidden_size=3, so output last dimension = 3 * 2 = 6. Batch=4, seq_len=7, so output shape = [4, 7, 6].
Final Answer:
[4, 7, 6] -> Option C
Quick Check:
Output last dim = hidden_size * 2 [OK]
- Forgetting to multiply hidden_size by 2
- Mixing batch and sequence dimensions
- Assuming output shape matches input exactly
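The shape rule above can be verified directly, along with how the two directions are laid out in the output. This is a sketch using the same sizes as the question; the slicing into forward and backward halves and the names fwd_final, bwd_final, and summary are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=5, hidden_size=3, bidirectional=True, batch_first=True)
x = torch.randn(4, 7, 5)  # batch=4, seq_len=7, input_size=5

with torch.no_grad():
    output, h_n = rnn(x)

print(output.shape)  # torch.Size([4, 7, 6])
print(h_n.shape)     # torch.Size([2, 4, 3]); not affected by batch_first

# The forward direction finishes at the last step, the backward at the first.
fwd_final = output[:, -1, :3]
bwd_final = output[:, 0, 3:]
print(torch.allclose(fwd_final, h_n[0]), torch.allclose(bwd_final, h_n[1]))

# A common sequence-level feature: concatenate both final states.
summary = torch.cat([h_n[0], h_n[1]], dim=1)
print(summary.shape)  # torch.Size([4, 6])
```

Taking output[:, -1] alone as a sequence summary would pair the forward direction's final state with the backward direction's first step. Concatenating the two entries of h_n avoids that.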
import torch
rnn = torch.nn.RNN(input_size=8, hidden_size=4, bidirectional=True)
input = torch.randn(5, 10, 8)
output, hidden = rnn(input)
What is the likely cause of the error?
Solution
Step 1: Check default input shape for PyTorch RNN
By default, PyTorch RNN expects input shape (seq_len, batch, input_size) unless batch_first=True is set.
Step 2: Analyze given input shape
The input has shape (5, 10, 8), intended as (batch, seq_len, input_size), but batch_first=True is not set, so the RNN reads it as seq_len=5 and batch=10. Because the last dimension still equals input_size, the forward pass itself runs without an exception; the mismatch surfaces as wrongly shaped outputs and hidden states in later code.
Final Answer:
Create the RNN with batch_first=True, or transpose the input to (seq_len, batch, input_size). -> Option A
Quick Check:
Default RNN input shape = (seq_len, batch, input_size) [OK]
- Assuming bidirectional disables shape rules
- Thinking hidden_size must match input_size
- Passing 2D input instead of 3D
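Both fixes named in the answer can be sketched side by side, together with the unfixed call for comparison. The variable names are illustrative, and the input is assumed to be intended as (batch=5, seq_len=10, input_size=8).

```python
import torch
from torch import nn

x = torch.randn(5, 10, 8)  # intended as (batch=5, seq_len=10, input_size=8)

# Option 1: tell the RNN that the batch dimension comes first.
rnn_bf = nn.RNN(input_size=8, hidden_size=4, bidirectional=True, batch_first=True)
out_bf, h_bf = rnn_bf(x)

# Option 2: keep the default layout and transpose the input instead.
rnn_default = nn.RNN(input_size=8, hidden_size=4, bidirectional=True)
out_t, h_t = rnn_default(x.transpose(0, 1).contiguous())

# Unfixed call: runs, but treats the tensor as seq_len=5, batch=10.
out_bad, h_bad = rnn_default(x)

print(h_bf.shape)   # torch.Size([2, 5, 4])   batch of 5, as intended
print(h_t.shape)    # torch.Size([2, 5, 4])   batch of 5, as intended
print(h_bad.shape)  # torch.Size([2, 10, 4])  batch silently read as 10
```

The hidden state's batch dimension is a convenient place to spot the mix-up, since the output tensor keeps the same (5, 10, 8) shape in the unfixed call.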
Solution
Step 1: Understand variable-length sequence handling
Packing padded sequences lets PyTorch RNNs skip the padding positions. This matters most for a bidirectional model, whose backward pass would otherwise start from the padding rather than from the last real token.
Step 2: Apply packing with bidirectional LSTM
Use pack_padded_sequence before feeding to an LSTM with bidirectional=True, then unpack with pad_packed_sequence.
Final Answer:
Use pack_padded_sequence before the LSTM and pad_packed_sequence after, with batch_first=True and bidirectional=True set. -> Option D
Quick Check:
Pack sequences for variable length + bidirectional LSTM [OK]
- Ignoring packing and feeding padded sequences directly
- Disabling bidirectional for variable lengths
- Manually reversing sequences instead of using bidirectional flag
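The pack and unpack steps from the answer can be sketched as follows. The batch, sequence lengths, and layer sizes are hypothetical values chosen for illustration; enforce_sorted=False is used so the batch does not need to be sorted by length.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=6, hidden_size=4, batch_first=True, bidirectional=True)

# Hypothetical batch: 3 sequences padded to length 5, true lengths 5, 3, 2.
padded = torch.randn(3, 5, 6)
lengths = torch.tensor([5, 3, 2])  # must be a CPU tensor (or a list)

with torch.no_grad():
    packed = pack_padded_sequence(padded, lengths,
                                  batch_first=True, enforce_sorted=False)
    packed_out, (h_n, c_n) = lstm(packed)
    output, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

print(output.shape)  # torch.Size([3, 5, 8])  -> hidden_size * 2 features
print(out_lengths)   # tensor([5, 3, 2])

# Positions past each sequence's true length are filled with zeros.
print(output[2, 2:].abs().sum().item())  # 0.0

# The forward final state of sequence 1 comes from its last REAL step (index 2).
last_fwd_matches = torch.allclose(h_n[0][1], output[1, 2, :4])
print(last_fwd_matches)
```

Without packing, h_n would be taken after the padded steps, and the backward direction would begin its pass on padding.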
