
Confusion matrix analysis in TensorFlow - ML Experiment: Train & Evaluate

Experiment - Confusion matrix analysis
Problem: You trained a simple neural network to classify images into 3 categories. The model shows 90% training accuracy but only 75% validation accuracy. You want to understand where the model makes mistakes by analyzing the confusion matrix.
Current Metrics: Training accuracy: 90%, Validation accuracy: 75%
Issue: The validation accuracy is much lower than the training accuracy, indicating possible overfitting or class confusion. You need to analyze the confusion matrix to identify which classes are confused.
Your Task
Generate and analyze the confusion matrix for the validation set predictions. Identify which classes the model confuses the most.
Use TensorFlow and Keras only.
Do not change the model architecture or training parameters.
Use the validation dataset only for confusion matrix analysis.
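Before running the full experiment, it helps to pin down what a confusion matrix actually counts. The minimal NumPy sketch below tallies each (true, predicted) label pair using made-up toy labels. Rows are true classes and columns are predicted classes, which is the convention `tf.math.confusion_matrix` follows. The helper name `confusion_counts` is hypothetical and only for illustration.

```python
import numpy as np

def confusion_counts(y_true, y_pred, num_classes):
    """Hypothetical helper: count (true, predicted) pairs; rows = true class, columns = predicted class."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates repeated (i, j) index pairs, unlike cm[y_true, y_pred] += 1
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

# Made-up toy labels for a 3-class problem
y_true = np.array([0, 0, 1, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 2, 2, 2, 2, 0])
cm = confusion_counts(y_true, y_pred, num_classes=3)
print(cm)
# [[1 1 0]
#  [0 1 2]
#  [1 0 2]]
```

Correct predictions land on the diagonal. Here the two class-1 images predicted as class 2 show up as the 2 in row 1, column 2.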
Solution
```python
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, Flatten

NUM_CLASSES = 3
CLASS_NAMES = ['0', '1', '2']

# Load MNIST as a stand-in dataset (3 classes: digits 0, 1, 2)
(x_train, y_train), (x_val, y_val) = mnist.load_data()

# Keep only classes 0, 1, 2 for simplicity
train_filter = np.isin(y_train, [0, 1, 2])
val_filter = np.isin(y_val, [0, 1, 2])
x_train, y_train = x_train[train_filter], y_train[train_filter]
x_val, y_val = x_val[val_filter], y_val[val_filter]

# Scale pixel values from [0, 255] to [0, 1]
x_train = x_train / 255.0
x_val = x_val / 255.0

# Build the simple model: Flatten -> Dense(64, relu) -> Dense(3, softmax)
model = Sequential([
    Input(shape=(28, 28)),
    Flatten(),
    Dense(64, activation='relu'),
    Dense(NUM_CLASSES, activation='softmax'),
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the model
model.fit(x_train, y_train, epochs=5, batch_size=32, validation_data=(x_val, y_val))

# Predict on the validation set: softmax probabilities -> most likely class index
val_preds_prob = model.predict(x_val)
val_preds = np.argmax(val_preds_prob, axis=1)

# Compute the confusion matrix with TensorFlow only (rows = true class, columns = predicted class)
cm = tf.math.confusion_matrix(y_val, val_preds, num_classes=NUM_CLASSES).numpy()
print(cm)

# Plot the confusion matrix
fig, ax = plt.subplots()
cax = ax.matshow(cm, cmap='Blues')
fig.colorbar(cax)
ax.set_title('Confusion Matrix')
# Set tick positions before tick labels to avoid a matplotlib warning
ax.set_xticks(range(NUM_CLASSES))
ax.set_yticks(range(NUM_CLASSES))
ax.set_xticklabels(CLASS_NAMES)
ax.set_yticklabels(CLASS_NAMES)
# matshow puts x ticks on top; move them to the bottom next to the 'Predicted' label
ax.xaxis.set_ticks_position('bottom')
ax.set_xlabel('Predicted')
ax.set_ylabel('True')

# Write the count in each cell
for (i, j), val in np.ndenumerate(cm):
    ax.text(j, i, val, ha='center', va='center', color='red')

plt.show()
```

Filtered the dataset to three classes (digits 0, 1, 2) so the confusion matrix stays small and readable.
Trained the simple neural network on the filtered data.
Predicted validation labels and computed the confusion matrix with tf.math.confusion_matrix, keeping to the TensorFlow-only constraint.
Visualized the confusion matrix with matplotlib to identify class confusions.
Set tick positions before tick labels to avoid a matplotlib warning.
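The confusion matrix `cm` from the solution is easier to read once it is reduced to per-class numbers. Below is a minimal NumPy sketch, assuming `cm` is a square array of counts with rows = true class and columns = predicted class. The helper `per_class_metrics` and the example counts are hypothetical, not results from this experiment.

```python
import numpy as np

def per_class_metrics(cm):
    """Hypothetical helper: per-class recall and precision, overall accuracy,
    and a row-normalized matrix from a (true x predicted) confusion matrix."""
    cm = np.asarray(cm, dtype=np.float64)
    correct = np.diag(cm)
    # Recall: share of each true class (row) that was predicted correctly
    recall = correct / np.maximum(cm.sum(axis=1), 1)
    # Precision: share of each predicted class (column) that really belongs to it
    precision = correct / np.maximum(cm.sum(axis=0), 1)
    # Overall accuracy is the diagonal total over all samples
    accuracy = correct.sum() / cm.sum()
    # Each row sums to 1, so classes of different sizes are directly comparable
    row_normalized = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1)
    return recall, precision, accuracy, row_normalized

# Hypothetical counts for 3 classes, 50 validation images each
cm = [[50, 0, 0],
      [0, 40, 10],
      [5, 5, 40]]
recall, precision, accuracy, row_normalized = per_class_metrics(cm)
# recall is about [1.0, 0.8, 0.8]; accuracy is 130 / 150
```

Low recall for a class means its images are often misclassified. Low precision means other classes are often predicted as it. Overall accuracy, the single number reported before the analysis, is just the diagonal sum divided by the total.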
Results Interpretation

Before: Only overall accuracy metrics were available (Training: 90%, Validation: 75%).

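After: The confusion matrix breaks the validation errors down by class pair. The entry in row i, column j counts validation images of true class i that the model predicted as class j, so correct predictions sit on the diagonal and every off-diagonal entry is a specific confusion. For example, class 1 might be predicted as class 2 more often than the reverse.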
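To answer the task's question directly (which classes are confused the most?), one option is to take the largest off-diagonal entry. Below is a minimal sketch under the same row/column convention. The helper names `most_confused_pair` and `most_confused_either_way` and the example counts are hypothetical. The second helper adds both directions (i predicted as j, plus j predicted as i) to rank unordered class pairs.

```python
import numpy as np

def most_confused_pair(cm):
    """Hypothetical helper: (true_class, predicted_class, count) for the largest off-diagonal entry."""
    errors = np.array(cm, dtype=np.int64)  # copy, so the caller's matrix is untouched
    np.fill_diagonal(errors, 0)  # ignore correct predictions
    i, j = np.unravel_index(np.argmax(errors), errors.shape)
    return int(i), int(j), int(errors[i, j])

def most_confused_either_way(cm):
    """Hypothetical helper: class pair (i < j) with the most errors in both directions combined."""
    errors = np.array(cm, dtype=np.int64)
    np.fill_diagonal(errors, 0)
    both_ways = errors + errors.T
    # Keep only the upper triangle so each unordered pair is counted once
    i, j = np.unravel_index(np.argmax(np.triu(both_ways, k=1)), both_ways.shape)
    return int(i), int(j), int(both_ways[i, j])

# Hypothetical counts (rows = true class, columns = predicted class)
cm = [[50, 0, 0],
      [0, 40, 10],
      [5, 5, 40]]
print(most_confused_pair(cm))        # (1, 2, 10): class 1 predicted as class 2
print(most_confused_either_way(cm))  # (1, 2, 15): 10 + 5 errors between classes 1 and 2
```

`np.argmax` returns the first maximum, so ties are broken by position. When several pairs are close, it is worth looking at the full matrix rather than only the top entry.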

The confusion matrix shows exactly which classes the model confuses. That guides targeted improvements, such as collecting more data for the confused classes or reweighting those classes during training.
Bonus Experiment
Try adding class weights during training to reduce confusion between the most confused classes.
💡 Hint
Calculate class weights inversely proportional to class frequencies and pass them to model.fit() using the class_weight parameter.
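Here is a minimal sketch of the hint, assuming integer labels 0..K-1 that all appear in the training labels. Each weight is `total / (num_classes * count)`, the same "balanced" heuristic that scikit-learn's `compute_class_weight` uses. The helper name `inverse_frequency_class_weights` is hypothetical.

```python
import numpy as np

def inverse_frequency_class_weights(labels):
    """Hypothetical helper: weight = total / (num_classes * class_count), so rarer classes weigh more.
    Returns {class_index: weight}, the dict form Keras accepts for class_weight."""
    classes, counts = np.unique(labels, return_counts=True)
    weights = len(labels) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}

# Made-up labels: class 0 is twice as frequent as classes 1 and 2
y_toy = np.array([0] * 6 + [1] * 3 + [2] * 3)
weights = inverse_frequency_class_weights(y_toy)
print(weights)  # class 0 gets half the weight of classes 1 and 2
```

In the experiment the dictionary would be passed as `model.fit(x_train, y_train, ..., class_weight=inverse_frequency_class_weights(y_train))`. On the filtered MNIST data the three digits are roughly balanced, so these weights stay close to 1. Any reduction in confusion is something to measure with a new confusion matrix, not something to assume.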