import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from sklearn.metrics import confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
# Load MNIST and keep only the digits 0, 1 and 2, turning it into a small
# 3-class problem for this example.
(x_train, y_train), (x_val, y_val) = mnist.load_data()

# Boolean row masks: True where the label is one of the kept digits.
train_filter = np.isin(y_train, [0, 1, 2])
val_filter = np.isin(y_val, [0, 1, 2])

x_train = x_train[train_filter]
y_train = y_train[train_filter]
x_val = x_val[val_filter]
y_val = y_val[val_filter]

# Rescale pixel values by 1/255 (assuming 8-bit pixels, this maps them onto [0, 1]).
x_train, x_val = x_train / 255.0, x_val / 255.0
# Build a small fully connected classifier: flatten each 28x28 input, one
# hidden ReLU layer of 64 units, and a 3-way softmax (one unit per kept digit).
# An explicit Input layer is used instead of Flatten(input_shape=...), which
# Keras 3 warns about for Sequential models; this form also works on TF 2.x.
model = Sequential([
    tf.keras.Input(shape=(28, 28)),
    Flatten(),
    Dense(64, activation='relu'),
    Dense(3, activation='softmax'),
])
# sparse_categorical_crossentropy consumes integer class ids directly, so the
# labels need no one-hot encoding.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train model; validation_data is only evaluated after each epoch, not fitted on.
model.fit(x_train, y_train, epochs=5, batch_size=32, validation_data=(x_val, y_val))

# Predict on validation set: per-class probabilities, then the most likely class id.
val_preds_prob = model.predict(x_val)
val_preds = np.argmax(val_preds_prob, axis=1)
# Compute confusion matrix (rows = true class, columns = predicted class).
# Passing `labels` explicitly keeps it 3x3 even if some digit never appears in
# y_val or val_preds, so it always lines up with the tick labels below.
class_labels = [0, 1, 2]
cm = confusion_matrix(y_val, val_preds, labels=class_labels)

# Plot confusion matrix, using the Axes API throughout.
fig, ax = plt.subplots()
cax = ax.matshow(cm, cmap='Blues')
ax.set_title('Confusion Matrix')
fig.colorbar(cax)
tick_positions = list(range(len(class_labels)))
tick_names = [str(label) for label in class_labels]
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(tick_names)
ax.set_yticklabels(tick_names)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')

# Show the count in each cell; use white text on dark (high-count) cells and
# black on light ones so the numbers stay legible on the Blues colormap.
threshold = cm.max() / 2
for (i, j), val in np.ndenumerate(cm):
    ax.text(j, i, str(val), ha='center', va='center',
            color='white' if val > threshold else 'black')
plt.show()