When we fine-tune BERT for classification, the goal is to assign each text to the correct category. The key metrics to check are accuracy, precision, recall, and F1 score. Accuracy is the fraction of all texts labeled correctly. Precision is the fraction of texts predicted as a class that truly belong to that class. Recall is the fraction of texts that truly belong to a class that the model actually finds. The F1 score is the harmonic mean of precision and recall, which matters when classes are imbalanced or when different mistakes have different costs.
BERT fine-tuning for classification in NLP - Model Metrics & Evaluation
| | Predicted Positive | Predicted Negative |
|---|---|---|
| Actual Positive | True Positive (TP): 80 | False Negative (FN): 20 |
| Actual Negative | False Positive (FP): 10 | True Negative (TN): 90 |
Total samples = TP + FP + TN + FN = 80 + 10 + 90 + 20 = 200
Accuracy = (TP + TN) / Total = (80 + 90) / 200 = 0.85
Precision = TP / (TP + FP) = 80 / (80 + 10) ≈ 0.89
Recall = TP / (TP + FN) = 80 / (80 + 20) = 0.80
F1 Score = 2 * (Precision * Recall) / (Precision + Recall) = 2 * (0.89 * 0.80) / (0.89 + 0.80) ≈ 0.84
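The arithmetic above can be checked with a few lines of plain Python. This is a minimal sketch that takes the four confusion-matrix counts as inputs; the function name `metrics_from_counts` is illustrative, not part of any library.

```python
def metrics_from_counts(tp: int, fp: int, tn: int, fn: int) -> dict:
    """Compute accuracy, precision, recall and F1 from binary confusion-matrix counts."""
    total = tp + fp + tn + fn
    accuracy = (tp + tn) / total
    # Guard against division by zero when there are no predicted or no actual positives.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Counts from the confusion matrix above.
m = metrics_from_counts(tp=80, fp=10, tn=90, fn=20)
for name, value in m.items():
    print(f"{name}: {value:.2f}")
```

Computing F1 from the unrounded precision (80/90) gives about 0.842, which still rounds to 0.84.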
Imagine BERT is classifying emails as spam or not spam, with spam as the positive class.
- High Precision: Few legitimate emails are wrongly marked as spam, so users don't miss important messages. The trade-off is that some spam may get through.
- High Recall: Most spam emails are caught, but some legitimate emails may be wrongly marked as spam, which annoys users.
Depending on which error matters more, we adjust the model or its decision threshold. For spam filtering, high precision is usually preferred so that legitimate emails are not lost.
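To see how the decision threshold trades precision against recall, here is a minimal sketch in plain Python. The spam probabilities and labels are made-up toy values (in practice they would come from a softmax over the fine-tuned model's logits), and `precision_recall_at` is a hypothetical helper.

```python
def precision_recall_at(probs, labels, threshold):
    """Precision and recall for the positive class (spam = 1) at a given threshold."""
    preds = [1 if p >= threshold else 0 for p in probs]
    tp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 1)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


# Toy data: predicted spam probability per email, and the true label (1 = spam).
probs = [0.95, 0.90, 0.80, 0.70, 0.60, 0.40, 0.30, 0.20]
labels = [1, 1, 0, 1, 0, 1, 0, 0]

for t in (0.5, 0.75, 0.9):
    p, r = precision_recall_at(probs, labels, t)
    print(f"threshold={t}: precision={p:.2f}, recall={r:.2f}")
```

On this toy data, raising the threshold from 0.5 to 0.9 moves precision from 0.60 to 1.00 while recall drops from 0.75 to 0.50: a stricter filter flags fewer legitimate emails but lets more spam through.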
Good: Accuracy above 85%, precision and recall both above 80%, and an F1 score of about 0.8 or higher. The model predicts well and finds most true cases without many false alarms. (These are rough rules of thumb; suitable targets depend on the task and the class balance.)
Bad: Accuracy near 50% on a balanced two-class task (no better than random guessing), precision or recall below 50%, or a large gap between precision and recall (e.g., high precision but very low recall), which pulls F1 down. The model is unreliable or misses many true cases.
- Accuracy paradox: High accuracy can be misleading when classes are imbalanced. For example, if 90% of texts are class A, a model that predicts A for every text scores 90% accuracy without learning anything.
- Data leakage: If test examples leak into the training data, metrics look too good, but the model fails in real use.
- Overfitting: Very high training accuracy combined with much lower test accuracy means the model memorized the training data instead of learning general patterns.
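The accuracy paradox is easy to reproduce. In this sketch (toy labels, plain Python), a degenerate "model" that always predicts the majority class A reaches 90% accuracy while its recall on the minority class B is zero.

```python
# Toy dataset: 90 texts of class "A" and 10 of class "B".
y_true = ["A"] * 90 + ["B"] * 10

# A degenerate "classifier" that ignores the input and always predicts the majority class.
y_pred = ["A"] * len(y_true)

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# Recall for class B: how many real B texts were labeled B.
b_total = y_true.count("B")
b_found = sum(t == "B" and p == "B" for t, p in zip(y_true, y_pred))
recall_b = b_found / b_total

print(f"accuracy = {accuracy:.2f}")   # 0.90
print(f"recall(B) = {recall_b:.2f}")  # 0.00
```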
Question: Your BERT model has 98% accuracy but only 12% recall on the positive class (e.g., fraud detection). Is this good for production? Why or why not?
Answer: No, it is not good. The model misses 88% of actual positive cases, which is very risky in fraud detection. High accuracy is misleading because most data is negative. You need to improve recall to catch more fraud cases.
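To see how both numbers can hold at once, here is a set of hypothetical counts, invented purely for illustration and not taken from any real dataset: 10,000 transactions, of which 200 are fraud.

```python
# Hypothetical confusion-matrix counts for a fraud classifier (illustrative only).
tp, fn = 24, 176    # 200 real fraud cases: only 24 caught
fp, tn = 24, 9776   # 9,800 legitimate transactions: 24 falsely flagged

total = tp + fn + fp + tn
accuracy = (tp + tn) / total
recall = tp / (tp + fn)

print(f"total={total}, accuracy={accuracy:.2f}, recall={recall:.2f}")
```

Because 98% of these transactions are legitimate, the model scores 0.98 accuracy while catching only 12% of fraud; per-class recall exposes what accuracy hides.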
Practice
Solution
Step 1: Understand BERT's pretraining
BERT is pretrained on general language tasks (masked language modeling and next-sentence prediction), so it needs further training for a specific task such as classification.
Step 2: Purpose of fine-tuning
Fine-tuning adapts BERT's learned language understanding to the categories in your dataset, usually by adding a small classification layer on top and training the whole model on labeled examples.
Final Answer:
To adapt BERT's knowledge to classify specific categories in your data -> Option A
Quick Check:
Fine-tuning = adapt BERT for classification ✓
Common mistakes:
- Thinking fine-tuning trains BERT from zero
- Confusing fine-tuning with model compression
- Assuming BERT outputs images
Solution
Step 1: Identify proper BERT tokenization method
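A BERT tokenizer converts raw text into WordPiece token IDs and an attention mask, adding the special [CLS] and [SEP] tokens.
Step 2: Compare options
`tokens = tokenizer.encode_plus(text, return_tensors='pt')` calls `encode_plus` with `return_tensors='pt'`, so it returns PyTorch tensors that can be passed straight to a BERT model. (In recent versions of Hugging Face transformers, calling the tokenizer directly, `tokenizer(text, return_tensors='pt')`, is the recommended equivalent.)
Final Answer:
`tokens = tokenizer.encode_plus(text, return_tensors='pt')` -> Option B
Quick Check:
Use encode_plus for BERT tokenization ✓
Common mistakes: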
- Using simple split instead of tokenizer
- Only tokenizing without encoding IDs
- Not returning tensors for model input
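A minimal runnable sketch of the tokenization step, assuming Hugging Face transformers 4.x and PyTorch are installed. To avoid downloading a pretrained vocabulary, it writes a tiny hypothetical vocabulary file; it calls the tokenizer directly, which returns the same fields (`input_ids`, `token_type_ids`, `attention_mask`) as `encode_plus` for a single text.

```python
import os
import tempfile

from transformers import BertTokenizer

# Hypothetical toy vocabulary so the example runs offline; a real project would
# load a pretrained one with BertTokenizer.from_pretrained("bert-base-uncased").
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "spam", "free", "money"]

with tempfile.TemporaryDirectory() as tmp:
    vocab_path = os.path.join(tmp, "vocab.txt")
    with open(vocab_path, "w", encoding="utf-8") as f:
        f.write("\n".join(vocab) + "\n")

    # do_lower_case=True lowercases "Free" before looking it up in the vocabulary.
    tokenizer = BertTokenizer(vocab_file=vocab_path, do_lower_case=True)
    tokens = tokenizer("Free money", return_tensors="pt")

print(tokens["input_ids"])       # IDs for [CLS] free money [SEP]
print(tokens["attention_mask"])  # 1 for every real (non-padding) token
```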
What does `print(predictions.argmax(dim=1))` print if the model predicts logits [[2.0, 1.0], [0.5, 1.5]] for two samples?

```python
import torch

logits = torch.tensor([[2.0, 1.0], [0.5, 1.5]])
predictions = logits
print(predictions.argmax(dim=1))
```
Solution
Step 1: Understand argmax(dim=1)
Argmax along dim=1 finds the index of the largest value in each row, i.e., the predicted class for each sample.
Step 2: Calculate argmax for each sample
First row: the max is 2.0 at index 0. Second row: the max is 1.5 at index 1.
Final Answer:
tensor([0, 1]) -> Option D
Quick Check:
Argmax per row = [0, 1] ✓
Common mistakes:
- Confusing dim=0 with dim=1
- Mixing up indices and values
- Expecting values instead of indices
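Logits are unnormalized scores. To turn them into class probabilities, apply a softmax over the class dimension; because softmax preserves the ordering within each row, the argmax does not change. A short PyTorch sketch with the same logits:

```python
import torch

logits = torch.tensor([[2.0, 1.0], [0.5, 1.5]])

# Softmax over the class dimension turns each row into probabilities that sum to 1.
probs = torch.softmax(logits, dim=1)

# Softmax is monotonic within a row, so the predicted class is the same either way.
pred_from_logits = logits.argmax(dim=1)
pred_from_probs = probs.argmax(dim=1)

print(probs)
print(pred_from_logits, pred_from_probs)
```

Use the probabilities when you need a confidence score or a custom decision threshold; for a plain top-1 prediction, argmax on the logits is enough.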
Your training step raises `TypeError: forward() missing 1 required positional argument: 'labels'`. What is the likely fix?

```python
outputs = model(input_ids, attention_mask)
loss = outputs.loss
loss.backward()
```
Solution
Step 1: Understand error cause
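The model's forward pass needs labels to compute the loss, but the call does not pass them. (With Hugging Face's `BertForSequenceClassification`, `labels` is optional: omitting it leaves `outputs.loss` as `None`, so `loss.backward()` fails with an `AttributeError` instead. The fix is the same.)
Step 2: Fix by passing labels
Include the labels argument in the model call so the loss is computed: `model(input_ids, attention_mask, labels=labels)`.
Final Answer:
Pass labels to the model call: `model(input_ids, attention_mask, labels=labels)` -> Option A
Quick Check:
Missing labels argument causes loss error ✓
Common mistakes: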
- Ignoring the missing labels argument
- Removing backward call instead of fixing input
- Changing variable names incorrectly
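The fix can be sketched end to end, assuming Hugging Face transformers and PyTorch are installed. A tiny, randomly initialized `BertForSequenceClassification` stands in for a pretrained checkpoint so nothing is downloaded; the configuration sizes and the batch are hypothetical.

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

torch.manual_seed(0)

# Tiny random BERT so the sketch runs offline; real fine-tuning would start from
# BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2).
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=2,
)
model = BertForSequenceClassification(config)

# Hypothetical batch: 2 sequences of 6 token IDs, all positions unmasked.
input_ids = torch.randint(0, config.vocab_size, (2, 6))
attention_mask = torch.ones_like(input_ids)
labels = torch.tensor([1, 0])

# Passing labels makes the model compute the classification loss.
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
loss.backward()  # gradients now populated for the encoder and the classifier head

print(outputs.logits.shape, loss.item())
```

With `num_labels=2` and integer labels, the returned loss is the cross-entropy between `outputs.logits` and `labels`.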
Solution
Step 1: Identify overfitting risks
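Small datasets can cause the model to memorize the training examples instead of generalizing.
Step 2: Apply regularization techniques
A small learning rate (the original BERT paper fine-tuned with values such as 2e-5, 3e-5, or 5e-5) keeps updates to the pretrained weights gentle, and dropout randomly zeroes activations during training so the model cannot rely on any single feature. Together they reduce overfitting.
Final Answer:
Use a small learning rate and add dropout layers -> Option C
Quick Check:
Small LR + dropout reduces overfitting ✓
Common mistakes: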
- Training longer without regularization
- Skipping tokenization
- Removing classification head incorrectly
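A minimal PyTorch sketch of the two regularization choices: dropout before a classification head, and AdamW with a small learning rate. The `ClassificationHead` module, the 768-dimensional input (BERT-base's hidden size), and the values for dropout (0.1), learning rate (2e-5), and weight decay (0.01) are illustrative assumptions, not prescribed settings.

```python
import torch
from torch import nn


class ClassificationHead(nn.Module):
    """Hypothetical head applied to BERT's pooled [CLS] representation."""

    def __init__(self, hidden_size: int = 768, num_labels: int = 2, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)  # zeroes random features, in training mode only
        self.linear = nn.Linear(hidden_size, num_labels)

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        return self.linear(self.dropout(pooled))


head = ClassificationHead()
# Small learning rate: gentle updates. In real fine-tuning the optimizer would also
# receive the encoder's parameters, not just the head's.
optimizer = torch.optim.AdamW(head.parameters(), lr=2e-5, weight_decay=0.01)

pooled = torch.randn(4, 768)  # stand-in for a batch of 4 pooled BERT outputs

head.train()
train_logits = head(pooled)   # dropout active: output varies between calls
head.eval()
eval_logits_1 = head(pooled)  # dropout disabled: deterministic
eval_logits_2 = head(pooled)

print(train_logits.shape, torch.equal(eval_logits_1, eval_logits_2))
```

Remember to call `model.eval()` before computing validation metrics; otherwise dropout stays active and the scores become noisy.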
