TensorFlow · ~8 mins

Multi-class classification model in TensorFlow - Model Metrics & Evaluation

Which metric matters for Multi-class classification and WHY

In multi-class classification, the model predicts one class out of many possible classes. The key metrics to check are accuracy, plus precision, recall, and F1-score for each class. Accuracy tells you how often the model is right overall. Precision for a class shows what fraction of samples predicted as that class actually belong to it. Recall for a class shows what fraction of that class's actual samples the model found. F1-score is the harmonic mean of precision and recall, balancing the two. Together, these metrics tell you whether the model distinguishes all classes well.
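These definitions can be computed directly from a list of labels. Below is a minimal sketch in plain Python; the animal labels and toy predictions are made up for illustration:

```python
# Per-class precision, recall, and F1 computed from scratch.
# The labels and predictions here are a hypothetical toy example.
def per_class_metrics(y_true, y_pred, label):
    """Precision, recall, and F1 for one class in a multi-class setting."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

y_true = ["cat", "cat", "dog", "dog", "rabbit", "cat"]
y_pred = ["cat", "dog", "dog", "dog", "rabbit", "cat"]

# Accuracy is one overall number; precision/recall/F1 are per class.
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(f"accuracy: {accuracy:.2f}")
for label in ("cat", "dog", "rabbit"):
    p, r, f1 = per_class_metrics(y_true, y_pred, label)
    print(f"{label}: precision={p:.2f} recall={r:.2f} f1={f1:.2f}")
```

In practice you would use a library routine such as scikit-learn's `classification_report` for this, but the hand-rolled version makes the per-class bookkeeping explicit.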

Confusion matrix for Multi-class classification

A confusion matrix shows how predictions match actual classes. For 3 classes (A, B, C), it looks like this:

                 | Predicted A | Predicted B | Predicted C |
      -----------|-------------|-------------|-------------|
      Actual A   |     50      |      2      |      3      |
      Actual B   |      4      |     45      |      1      |
      Actual C   |      5      |      2      |     43      |
    

Each row sums to the total number of samples of that actual class. The diagonal entries are correct predictions (the true positives for each class); off-diagonal entries are errors.
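Per-class precision and recall can be read straight off this matrix: for a class, precision divides the diagonal entry by its column sum, and recall divides it by its row sum. A small sketch using the numbers from the table above:

```python
# Confusion matrix from the table above: rows = actual, columns = predicted.
matrix = {
    "A": {"A": 50, "B": 2, "C": 3},
    "B": {"A": 4, "B": 45, "C": 1},
    "C": {"A": 5, "B": 2, "C": 43},
}
classes = ["A", "B", "C"]

def precision_recall(cls):
    tp = matrix[cls][cls]                           # diagonal entry
    fn = sum(matrix[cls][c] for c in classes) - tp  # rest of the row
    fp = sum(matrix[r][cls] for r in classes) - tp  # rest of the column
    return tp / (tp + fp), tp / (tp + fn)

for cls in classes:
    p, r = precision_recall(cls)
    print(f"{cls}: precision={p:.3f} recall={r:.3f}")
```

For class A, for example, precision is 50 / (50 + 4 + 5) ≈ 0.847 and recall is 50 / (50 + 2 + 3) ≈ 0.909.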

Precision vs Recall tradeoff with examples

Imagine a model classifying animals: cats, dogs, and rabbits.

  • High precision for cats: When the model says "cat," it is usually right. This is good if you want to avoid wrongly labeling dogs or rabbits as cats.
  • High recall for cats: The model finds most of the actual cats. This is important if missing a cat is costly, like in a pet shelter.

Improving precision may lower recall and vice versa. The F1-score helps balance both.
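Because F1 is the harmonic mean, it punishes a gap between precision and recall: two configurations with a similar average can have very different F1 scores. The numbers below are illustrative, not from the animal example:

```python
# F1 is the harmonic mean of precision and recall, so a large gap
# between the two drags F1 toward the lower value.
def f1(precision, recall):
    return 2 * precision * recall / (precision + recall)

print(f1(0.80, 0.80))  # balanced pair
print(f1(0.95, 0.40))  # similar average, but a big gap -> much lower F1
```

The balanced pair gives F1 = 0.80, while the imbalanced pair gives only about 0.56 even though its arithmetic mean (0.675) is not far off.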

What "good" vs "bad" metric values look like

For a balanced multi-class problem:

  • Good: Accuracy above 85%, precision and recall above 80% for all classes, and F1-scores close to precision and recall.
  • Bad: Accuracy below 60%, large differences in precision or recall between classes (e.g., 90% for one class but 30% for another), or very low F1-scores indicating poor balance.

Good metrics mean the model predicts all classes well. Bad metrics mean the model struggles with some classes or overall.

Common pitfalls in metrics for multi-class classification
  • Accuracy paradox: High accuracy can be misleading if classes are imbalanced. For example, if 90% of data is class A, predicting only A gives 90% accuracy but poor performance on other classes.
  • Ignoring per-class metrics: Overall accuracy hides if some classes are poorly predicted.
  • Data leakage: If test data leaks into training, metrics look unrealistically good.
  • Overfitting: Very high training accuracy but low test accuracy means the model memorizes training data and won't generalize.
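The accuracy paradox from the first bullet is easy to reproduce. With 90% of samples in class A, a degenerate model that always predicts "A" scores 90% accuracy while never finding a single B or C; the class split below is the hypothetical one from the bullet:

```python
# Accuracy paradox: 90% of the data is class A, and the "model"
# always predicts A.
y_true = ["A"] * 90 + ["B"] * 5 + ["C"] * 5
y_pred = ["A"] * 100

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
recall_b = sum(t == p == "B" for t, p in zip(y_true, y_pred)) / y_true.count("B")
print(accuracy)   # high overall accuracy
print(recall_b)   # zero recall for class B
```

This is exactly why per-class recall must be checked alongside overall accuracy on imbalanced data.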
Self-check question

Your multi-class model has 92% accuracy but the recall for class B is 40%. Is it good for production?

Answer: No, because the model misses many actual class B samples. Even with high overall accuracy, poor recall on a class means the model is not reliable for that class. You should improve recall for class B before production.
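To see how these two numbers can coexist, here is one hypothetical class split (the counts are invented for illustration) where overall accuracy is 92% while class B recall is only 40%:

```python
# Hypothetical counts: class B is small, so missing most of it
# barely dents overall accuracy.
true_counts = {"A": 70, "B": 10, "C": 20}   # samples per actual class
correct =     {"A": 69, "B": 4,  "C": 19}   # correctly predicted per class

accuracy = sum(correct.values()) / sum(true_counts.values())
recall_b = correct["B"] / true_counts["B"]
print(accuracy)   # overall accuracy
print(recall_b)   # recall for class B
```

Six of the ten class B samples are missed, yet only 8 of 100 predictions are wrong overall, which is how a 92% accurate model can still be unusable for class B.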

Key Result
In multi-class classification, balanced precision, recall, and F1-score per class are key to ensure the model predicts all classes well, beyond just overall accuracy.