
Answer span extraction in NLP - ML Experiment: Train & Evaluate

Experiment - Answer span extraction
Problem: We want to build a model that finds the exact answer span in a paragraph given a question. Currently, the model predicts the start and end positions of the answer in the text.
Current Metrics: Training loss: 0.15, Training accuracy (exact match): 85%, Validation loss: 0.40, Validation accuracy (exact match): 65%
Issue: The model is overfitting: training accuracy is high, but validation accuracy is 20 points lower.
Your Task
Reduce overfitting so that validation accuracy improves to at least 75%, while keeping training accuracy below 90%.
You cannot change the dataset or add more data.
You must keep the same model architecture (a simple BiLSTM with start/end classifiers).
Solution
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

class AnswerSpanModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, dropout_rate=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.bilstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout_rate)
        self.start_classifier = nn.Linear(hidden_dim * 2, 1)
        self.end_classifier = nn.Linear(hidden_dim * 2, 1)

    def forward(self, x):
        emb = self.embedding(x)
        lstm_out, _ = self.bilstm(emb)
        dropped = self.dropout(lstm_out)
        start_logits = self.start_classifier(dropped).squeeze(-1)
        end_logits = self.end_classifier(dropped).squeeze(-1)
        return start_logits, end_logits

# Assume train_loader and val_loader are defined elsewhere and yield
# (inputs, start_positions, end_positions) batches of integer tensors

model = AnswerSpanModel(vocab_size=10000, embedding_dim=100, hidden_dim=64, dropout_rate=0.3)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

best_val_acc = 0
patience = 3
trigger_times = 0

for epoch in range(20):
    model.train()
    for inputs, start_positions, end_positions in train_loader:
        optimizer.zero_grad()
        start_logits, end_logits = model(inputs)
        loss_start = criterion(start_logits, start_positions)
        loss_end = criterion(end_logits, end_positions)
        loss = loss_start + loss_end
        loss.backward()
        optimizer.step()

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, start_positions, end_positions in val_loader:
            start_logits, end_logits = model(inputs)
            pred_start = start_logits.argmax(dim=1)
            pred_end = end_logits.argmax(dim=1)
            correct += ((pred_start == start_positions) & (pred_end == end_positions)).sum().item()
            total += inputs.size(0)
    val_acc = correct / total * 100
    print(f"Epoch {epoch+1}, Validation Exact Match Accuracy: {val_acc:.2f}%")

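    if val_acc > best_val_acc:
        best_val_acc = val_acc
        trigger_times = 0
        # Keep the best weights so they can be restored after stopping
        # ("best_model.pt" is an illustrative file name)
        torch.save(model.state_dict(), "best_model.pt")
    else:
        # No improvement this epoch: count it toward the patience limit
        trigger_times += 1
        if trigger_times >= patience:
            print("Early stopping triggered")
            break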
Added dropout layer with rate 0.3 after BiLSTM to reduce overfitting.
Lowered learning rate from 0.01 to 0.001 for better convergence.
Implemented early stopping with patience of 3 epochs to avoid overtraining.
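The early stopping rule described above can be isolated from the training loop. The following is a minimal sketch in plain Python, not the exact code used in the experiment; the `EarlyStopping` class name and the validation accuracies are made up for illustration.

```python
class EarlyStopping:
    """Stop training when a monitored score stops improving."""

    def __init__(self, patience=3):
        self.patience = patience
        self.best_score = float("-inf")
        self.best_epoch = None
        self.bad_epochs = 0  # consecutive epochs without improvement

    def step(self, score, epoch):
        """Record one epoch's score; return True if training should stop."""
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Hypothetical validation accuracies, for illustration only
val_accuracies = [60.0, 68.0, 74.0, 73.5, 73.0, 72.0, 71.0]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, acc in enumerate(val_accuracies, start=1):
    if stopper.step(acc, epoch):
        stopped_at = epoch
        break

print(f"Stopped at epoch {stopped_at}, best epoch was {stopper.best_epoch}")
```

With a patience of 3, training halts three epochs after the best score, so the weights in memory at that point are not the best ones unless they were saved when the best score was reached.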
Results Interpretation

Before: Training accuracy 85%, Validation accuracy 65% (overfitting)

After: Training accuracy 88%, Validation accuracy 77% (reduced overfitting)

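Adding dropout and early stopping helps the model generalize better: the gap between training and validation accuracy shrinks from 20 points (85% vs. 65%) to 11 points (88% vs. 77%). This meets the target of at least 75% validation accuracy while keeping training accuracy below 90%.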
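The accuracies reported here are exact-match scores: a prediction counts only when both the start and the end position are correct. A small sketch of that metric in plain Python follows; the function name `exact_match_accuracy` and the sample spans are assumptions for illustration.

```python
def exact_match_accuracy(pred_spans, gold_spans):
    """Percentage of examples whose predicted (start, end) equals the gold pair."""
    if not gold_spans:
        return 0.0
    hits = sum(1 for pred, gold in zip(pred_spans, gold_spans) if pred == gold)
    return 100.0 * hits / len(gold_spans)


# Hypothetical spans: the second prediction has the right start but the wrong end
gold = [(1, 3), (0, 0), (5, 7), (2, 4)]
pred = [(1, 3), (0, 1), (5, 7), (2, 4)]

em = exact_match_accuracy(pred, gold)
print(f"Exact match: {em:.1f}%")
```

A prediction that gets only one of the two positions right earns no credit, which is why exact match is a strict metric.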
Bonus Experiment
Try using a pretrained language model like BERT for answer span extraction to improve accuracy.
💡 Hint
Use Hugging Face transformers library and fine-tune a BERT model on the same dataset.

Practice

1. What is the main goal of answer span extraction in NLP?
easy
A. To generate new text based on a prompt
B. To find the exact part of text that answers a question
C. To summarize long documents into short sentences
D. To translate text from one language to another

Solution

  1. Step 1: Understand the purpose of answer span extraction

    Answer span extraction focuses on locating the exact segment in a text that directly answers a question.
  2. Step 2: Compare with other NLP tasks

    Unlike translation, summarization, or text generation, answer span extraction pinpoints a specific text span as the answer.
  3. Final Answer:

    To find the exact part of text that answers a question -> Option B
  4. Quick Check:

    Answer span extraction = find exact answer span [OK]
Hint: Answer span extraction locates exact text answers [OK]
Common Mistakes:
  • Confusing answer span extraction with translation
  • Thinking it summarizes text instead of extracting spans
  • Assuming it generates new text
2. Which of the following is the correct way to represent the start and end positions for answer span extraction in code?
easy
A. start_index and end_index as integers
B. start_word and end_word as strings
C. start_time and end_time as floats
D. start_char and end_char as booleans

Solution

  1. Step 1: Identify typical data types for positions

    Positions in text are usually represented by integer indices marking start and end locations.
  2. Step 2: Evaluate options

    Strings or booleans do not represent positions well; floats for time are unrelated to text spans.
  3. Final Answer:

    start_index and end_index as integers -> Option A
  4. Quick Check:

    Positions = integer indices [OK]
Hint: Positions in text are integer indices [OK]
Common Mistakes:
  • Using strings instead of integer indices
  • Confusing character positions with time values
  • Using booleans for position markers
3. Given the text: 'The cat sat on the mat.' and predicted start index = 1, end index = 4, what is the extracted answer span?
medium
A. 'cat sat on'
B. 'sat on the'
C. 'on the mat'
D. 'The cat sat'

Solution

  1. Step 1: Identify tokens and their indices

    Tokenizing the sentence: ['The'(0), 'cat'(1), 'sat'(2), 'on'(3), 'the'(4), 'mat.'(5)]. The indices given (1 to 4) refer to 0-based token positions.
  2. Step 2: Extract tokens from start to end index

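    This question treats the end index as exclusive, like a Python slice: tokens[1:4] = ['cat'(1), 'sat'(2), 'on'(3)] = 'cat sat on'. Many QA models instead predict an inclusive end token, which here would give 'cat sat on the'; that span is not among the options, so always check which convention your data uses.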
  3. Final Answer:

    'cat sat on' -> Option A
  4. Quick Check:

    Extract tokens from start to end index = 'cat sat on' [OK]
Hint: Match indices to tokens carefully [OK]
Common Mistakes:
  • Confusing character indices with token indices
  • Off-by-one errors in slicing
  • Ignoring punctuation in tokens
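To make the off-by-one risk concrete, here is a short sketch that extracts the span under both conventions. It assumes whitespace tokenization, as in the solution; the helper names are illustrative.

```python
text = "The cat sat on the mat."
tokens = text.split()  # whitespace tokenization, punctuation stays attached


def span_exclusive(tokens, start, end):
    """Python-slice convention: the end index is not part of the span."""
    return " ".join(tokens[start:end])


def span_inclusive(tokens, start, end):
    """Inclusive convention: the end index points at the last answer token."""
    return " ".join(tokens[start:end + 1])


print(span_exclusive(tokens, 1, 4))
print(span_inclusive(tokens, 1, 4))
```

The same pair of indices yields two different answers, so the convention has to be fixed before predictions are compared against gold spans.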
4. You have a model that predicts start and end indices for answer spans but sometimes the end index is smaller than the start index. What is the best way to fix this bug?
medium
A. Ignore the prediction and return an empty answer
B. Always set end index to start index plus one
C. Swap the start and end indices if end < start
D. Use only the start index as the answer

Solution

  1. Step 1: Understand the problem with indices

    End index smaller than start index is invalid because answer spans must go forward in text.
  2. Step 2: Choose a fix that preserves valid spans

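    Among the options listed, swapping the indices is the only fix that restores a valid order while still using both predictions. It is a post-hoc repair; a more robust approach is to restrict the span search to pairs with start ≤ end, as question 5 describes.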
  3. Final Answer:

    Swap the start and end indices if end < start -> Option C
  4. Quick Check:

    Fix invalid spans by swapping indices [OK]
Hint: Swap indices if end < start to fix spans [OK]
Common Mistakes:
  • Ignoring invalid spans instead of fixing
  • Forcing fixed span length blindly
  • Using only one index loses answer context
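The swap fix can be sketched as a small post-processing step. The function name `repair_span` is hypothetical, and this is only the repair the question describes, not a guarantee that the swapped span is the right answer.

```python
def repair_span(start, end):
    """Return (start, end) in forward order, swapping if end < start."""
    if end < start:
        start, end = end, start
    return start, end


print(repair_span(5, 2))  # reversed prediction gets reordered
print(repair_span(2, 5))  # valid prediction is left unchanged
```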
5. In a question-answering system, the model outputs start logits and end logits for each token. How should you combine these to find the best answer span?
hard
A. Choose random start and end indices
B. Pick the token with the highest start logit only
C. Pick the token with the highest end logit only
D. Find the pair of start and end indices with the highest sum of start and end logits where start ≤ end

Solution

  1. Step 1: Understand logits for start and end tokens

    Start and end logits represent scores for each token being the start or end of the answer span.
  2. Step 2: Combine logits to find best span

    We look for the pair (start, end) with the highest combined score, ensuring start ≤ end to form a valid span.
  3. Final Answer:

    Find the pair of start and end indices with the highest sum of start and end logits where start ≤ end -> Option D
  4. Quick Check:

    Combine start and end logits to find best span [OK]
Hint: Sum start and end logits, ensure start ≤ end [OK]
Common Mistakes:
  • Ignoring end logits and using start only
  • Choosing invalid spans where end < start
  • Picking random indices without scores
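The constrained search in this answer can be sketched in plain Python. This is a simple exhaustive version under stated assumptions: the end index is inclusive, the logits are made-up values, and the optional `max_answer_len` cap is an addition for illustration rather than part of the question.

```python
def best_span(start_logits, end_logits, max_answer_len=None):
    """Return ((start, end), score) maximizing start + end logits with start <= end.

    The end index is inclusive. max_answer_len optionally caps the span length
    in tokens (an illustrative extra, not required by the question).
    """
    n = len(start_logits)
    best_score = float("-inf")
    best = (0, 0)
    for s in range(n):
        last = n if max_answer_len is None else min(n, s + max_answer_len)
        for e in range(s, last):  # only pairs with start <= end are scored
            score = start_logits[s] + end_logits[e]
            if score > best_score:
                best_score = score
                best = (s, e)
    return best, best_score


# Hypothetical logits: independent argmax would pick start=2, end=0 (invalid)
start_logits = [0.1, 0.2, 3.0, 0.5]
end_logits = [2.5, 0.3, 0.4, 1.0]

span, score = best_span(start_logits, end_logits)
print(span, score)
```

Taking the argmax of each logit vector separately can produce an end before the start; searching over valid pairs avoids that by construction. The double loop is quadratic in sequence length, which a length cap keeps manageable.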