TensorFlow · ML · ~20 mins

Error analysis patterns in TensorFlow - ML Experiment: Train & Evaluate

Experiment - Error analysis patterns
Problem: You trained a neural network to classify images into 3 categories. The training accuracy is 95%, but validation accuracy is only 70%. You want to understand why the model makes mistakes.
Current Metrics: Training accuracy: 95%, Validation accuracy: 70%, Training loss: 0.15, Validation loss: 0.85
Issue: The model overfits the training data and performs poorly on validation data. You need to analyze the errors to find patterns causing low validation accuracy.
Your Task
Perform error analysis on the validation set predictions to identify common patterns in misclassified images and suggest improvements.
Use TensorFlow and Python; NumPy, Matplotlib, and scikit-learn are used only for analysis and plotting.
Do not change the model architecture or training hyperparameters yet.
Focus on analyzing errors and visualizing them.
Solution
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Assume model and validation data are loaded
# model: trained TensorFlow model
# val_images, val_labels: validation dataset

# Get predictions
val_pred_probs = model.predict(val_images)
val_preds = np.argmax(val_pred_probs, axis=1)
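# Convert one-hot labels to class indices (integer labels are used as-is)
val_labels = np.asarray(val_labels)
true_labels = np.argmax(val_labels, axis=1) if val_labels.ndim > 1 else val_labels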

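# Calculate confusion matrix (rows: true class, columns: predicted class).
# Passing labels keeps the matrix 3x3 even if a class never appears.
cm = confusion_matrix(true_labels, val_preds, labels=[0, 1, 2])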

# Display confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Class 0", "Class 1", "Class 2"])
disp.plot(cmap=plt.cm.Blues)
plt.title("Confusion Matrix on Validation Set")
plt.show()

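# Calculate per-class accuracy (recall): correct predictions / true samples per class
class_totals = cm.sum(axis=1)
per_class_acc = cm.diagonal() / np.maximum(class_totals, 1)  # avoid division by zero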
for i, acc in enumerate(per_class_acc):
    print(f"Accuracy for Class {i}: {acc*100:.2f}%")

# Find indices of misclassified samples
misclassified_indices = np.where(val_preds != true_labels)[0]

# Visualize some misclassified images
num_to_show = 6
plt.figure(figsize=(12, 6))
for i, idx in enumerate(misclassified_indices[:num_to_show]):
    plt.subplot(2, 3, i+1)
    plt.imshow(val_images[idx])
    plt.title(f"True: {true_labels[idx]}, Pred: {val_preds[idx]}")
    plt.axis('off')
plt.suptitle("Examples of Misclassified Validation Images")
plt.show()
Added code to predict on validation data and compare with true labels.
Computed and visualized confusion matrix to identify which classes are confused.
Calculated per-class accuracy to find weak classes.
Displayed sample misclassified images to visually inspect error patterns.
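The solution shows the first six misclassified images in dataset order. One possible refinement, sketched here under the assumption that the softmax outputs are available as an array like `val_pred_probs`, is to look at the errors the model was most confident about, since those often point to mislabeled or systematically confusing images. The helper name `most_confident_errors` and the toy arrays are hypothetical and only stand in for the real validation data.

```python
import numpy as np

def most_confident_errors(pred_probs, true_labels, top_k=6):
    """Return indices of misclassified samples, most confident first."""
    pred_probs = np.asarray(pred_probs)
    true_labels = np.asarray(true_labels)
    preds = pred_probs.argmax(axis=1)
    confidence = pred_probs.max(axis=1)
    wrong = np.flatnonzero(preds != true_labels)
    # Sort the wrong predictions by descending confidence
    order = np.argsort(-confidence[wrong], kind="stable")
    return wrong[order][:top_k]

# Toy data standing in for val_pred_probs and true_labels (made up for illustration)
toy_probs = np.array([
    [0.90, 0.05, 0.05],  # predicts 0, true 0: correct
    [0.10, 0.80, 0.10],  # predicts 1, true 2: wrong, confidence 0.80
    [0.40, 0.35, 0.25],  # predicts 0, true 1: wrong, confidence 0.40
    [0.05, 0.05, 0.90],  # predicts 2, true 0: wrong, confidence 0.90
])
toy_true = np.array([0, 2, 1, 0])

print(most_confident_errors(toy_probs, toy_true))  # [3 1 2]
```

The returned indices could replace `misclassified_indices[:num_to_show]` in the plotting loop so the figure shows the most confident mistakes first.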
Results Interpretation

Before error analysis: You only knew overall validation accuracy was low (70%).

After error analysis: The confusion matrix shows which pairs of classes the model confuses most, and per-class accuracy shows which classes need improvement. Inspecting the misclassified images helps reveal whether errors stem from ambiguous images, poor lighting, or classes that look alike.
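Reading the confusion matrix by eye works for 3 classes, but ranking the off-diagonal cells makes the most common confusions explicit. The following is a minimal sketch, assuming a matrix laid out like `cm` in the solution (rows are true classes, columns are predicted classes). The function name `top_confused_pairs` and the toy matrix are hypothetical and are not results from the experiment.

```python
import numpy as np

def top_confused_pairs(cm, top_k=3):
    """Return (true_class, predicted_class, count) for the largest off-diagonal cells."""
    off_diag = np.array(cm, copy=True)
    np.fill_diagonal(off_diag, 0)  # ignore correct predictions
    # Indices of the flattened matrix, largest counts first
    flat_order = np.argsort(off_diag, axis=None, kind="stable")[::-1][:top_k]
    rows, cols = np.unravel_index(flat_order, off_diag.shape)
    return [
        (int(r), int(c), int(off_diag[r, c]))
        for r, c in zip(rows, cols)
        if off_diag[r, c] > 0
    ]

# Toy confusion matrix (made up for illustration)
toy_cm = np.array([
    [50,  3,  2],
    [ 4, 30, 16],
    [ 1,  9, 40],
])

for true_cls, pred_cls, count in top_confused_pairs(toy_cm):
    print(f"True class {true_cls} predicted as class {pred_cls}: {count} times")
```

In this toy matrix the largest confusion is class 1 predicted as class 2, which would be the first pair of classes to inspect visually.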

Error analysis helps you understand why your model makes mistakes. Instead of guessing, you use data-driven insights to guide improvements like collecting more data for weak classes, augmenting images, or adjusting the model.
Bonus Experiment
Try adding data augmentation to the training data and retrain the model to see if validation accuracy improves for the weak classes identified.
💡 Hint
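Use tf.image functions or Keras preprocessing layers such as tf.keras.layers.RandomFlip and tf.keras.layers.RandomRotation to add random flips, rotations, or brightness changes during training. ImageDataGenerator also works but is deprecated in recent TensorFlow releases.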
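As a rough sketch of the hint, the snippet below applies random flips and brightness changes with `tf.image` inside a `tf.data` pipeline. It assumes float images scaled to [0, 1] and one-hot labels; the `augment` function, the random placeholder tensors, and the image size of 32x32 are illustrative assumptions, not part of the original experiment. Augmentation is applied to the training data only, so the validation set stays comparable with the earlier results.

```python
import tensorflow as tf

def augment(image, label):
    """Apply random, label-preserving changes to one training image."""
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    # Brightness shifts can leave the assumed [0, 1] range, so clip back
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label

# Placeholder data standing in for the real training set (assumption)
toy_images = tf.random.uniform([8, 32, 32, 3])
toy_labels = tf.one_hot(tf.random.uniform([8], maxval=3, dtype=tf.int32), depth=3)

train_ds = (
    tf.data.Dataset.from_tensor_slices((toy_images, toy_labels))
    .shuffle(8)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(4)
)

batch_images, batch_labels = next(iter(train_ds))
print(batch_images.shape, batch_labels.shape)
```

A dataset built this way can be passed to `model.fit(train_ds, ...)`. Random rotations are left out of the sketch; a `tf.keras.layers.RandomRotation` layer is one way to add them.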