TensorFlow · ML · ~20 mins

Confusion matrix visualization in TensorFlow - ML Experiment: Train & Evaluate

Experiment - Confusion matrix visualization
Problem: You have trained a classification model using TensorFlow, but you want to better understand how well it predicts each class by visualizing the confusion matrix.
Current Metrics: Training accuracy 85%, validation accuracy 82%
Issue: The model's overall accuracy is known, but it is unclear which classes are confused with each other, making it hard to improve the model.
Your Task
Create and visualize a confusion matrix for the validation dataset predictions to identify which classes the model confuses most.
Use TensorFlow and matplotlib only.
Do not retrain or change the model architecture.
Use the validation dataset predictions and true labels.
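In practice, `y_pred` comes from the model's predicted class probabilities rather than a hand-written array. A minimal sketch of that step (the `probs` array below is a stand-in for the output of `model.predict(X_val)`, which is assumed to come from your own training run):

```python
import numpy as np

# Stand-in for `model.predict(X_val)`: one row of class
# probabilities per validation sample (values are illustrative).
probs = np.array([[0.8, 0.1, 0.1],
                  [0.2, 0.3, 0.5],
                  [0.1, 0.2, 0.7]])

# Predicted label = index of the highest-probability class per row
y_pred = np.argmax(probs, axis=1)
print(y_pred)  # [0 2 2]
```

With a real model, `y_true` would be the integer labels of the validation set (e.g. `np.asarray(y_val)`), and both arrays feed directly into `tf.math.confusion_matrix`.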
Solution
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

# Assume model, X_val, y_val are predefined
# For demonstration, create dummy data and model predictions
num_classes = 3

# Dummy true labels for validation set
y_true = np.array([0, 1, 2, 2, 1, 0, 1, 2, 0, 1])

# Dummy predicted labels from model
y_pred = np.array([0, 2, 2, 2, 1, 0, 0, 2, 0, 1])

# Compute confusion matrix
cm = tf.math.confusion_matrix(y_true, y_pred, num_classes=num_classes).numpy()

# Normalize by true-label counts (row sums); the np.maximum guard
# avoids division by zero for classes absent from the validation set
row_sums = cm.sum(axis=1)[:, np.newaxis]
cm_norm = cm.astype('float') / np.maximum(row_sums, 1)

# Plot confusion matrix
plt.figure(figsize=(6, 6))
plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Normalized Confusion Matrix')
plt.colorbar()

classes = [f'Class {i}' for i in range(num_classes)]
plt.xticks(np.arange(num_classes), classes, rotation=45)
plt.yticks(np.arange(num_classes), classes)

# Loop over data dimensions and create text annotations.
thresh = cm_norm.max() / 2.
for i in range(num_classes):
    for j in range(num_classes):
        plt.text(j, i, f'{cm[i, j]} ({cm_norm[i, j]:.2f})',
                 horizontalalignment='center',
                 color='white' if cm_norm[i, j] > thresh else 'black')

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
plt.show()
Added code to compute confusion matrix using tf.math.confusion_matrix.
Normalized the confusion matrix to show proportions per true class.
Visualized the matrix with matplotlib using color coding and annotations.
Added axis labels and color bar for better understanding.
Results Interpretation

Before: Only overall accuracy was known (Training: 85%, Validation: 82%).

After: Confusion matrix visualization reveals which classes are confused, e.g., Class 1 is sometimes predicted as Class 0 or 2.

Visualizing the confusion matrix helps understand detailed model performance beyond accuracy, showing specific class prediction errors to guide improvements.
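For example, per-class recall and precision can be read directly off the confusion matrix. Using the matrix produced by the dummy labels above (rows are true classes, columns are predicted classes):

```python
import numpy as np

# Confusion matrix for the dummy data above
# (rows = true labels, columns = predicted labels)
cm = np.array([[3, 0, 0],
               [1, 2, 1],
               [0, 0, 3]])

# Recall per class: correct predictions / all true samples of that class
recall = np.diag(cm) / cm.sum(axis=1)

# Precision per class: correct predictions / all predictions of that class
precision = np.diag(cm) / cm.sum(axis=0)

print(recall)     # [1.   0.5  1. ]
print(precision)  # [0.75 1.   0.75]
```

Here the low recall of Class 1 (0.5) pinpoints exactly where the model loses accuracy, which the single 82% figure could not show.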
Bonus Experiment
Try creating a confusion matrix heatmap using seaborn library for enhanced visualization.
💡 Hint
Use seaborn.heatmap with annotations and a suitable color palette for clearer visuals.
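A possible sketch using `seaborn.heatmap`, reusing the confusion matrix from the dummy data above (class names are illustrative):

```python
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Confusion matrix from the dummy data above
cm = np.array([[3, 0, 0],
               [1, 2, 1],
               [0, 0, 3]])
classes = [f'Class {i}' for i in range(3)]

# annot=True writes the count in each cell; fmt='d' formats as integers
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=classes, yticklabels=classes)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.title('Confusion Matrix (seaborn)')
plt.tight_layout()
plt.show()
```

Compared with the matplotlib-only version, seaborn handles the color scaling, cell annotations, and tick labels in a single call.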