import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
# Load example dataset (using CIFAR-10 for demonstration, only 3 classes).
# Keras returns images as uint8 arrays and labels as an (N, 1) integer column.
# Original CIFAR-10 class ids to keep; must be sorted ascending because labels
# are remapped with np.searchsorted below.
KEPT_CLASSES = np.array([0, 1, 2])
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# Filter for the kept classes only; flatten() turns the (N, 1) mask into a 1-D
# boolean row selector.
train_filter = np.isin(y_train, KEPT_CLASSES).flatten()
test_filter = np.isin(y_test, KEPT_CLASSES).flatten()
x_train, y_train = x_train[train_filter], y_train[train_filter]
x_test, y_test = x_test[test_filter], y_test[test_filter]
# Remap kept labels to contiguous indices 0..k-1 so they index the softmax
# output under sparse_categorical_crossentropy (identity for classes 0, 1, 2).
y_train = np.searchsorted(KEPT_CLASSES, y_train.flatten())
y_test = np.searchsorted(KEPT_CLASSES, y_test.flatten())
# Normalize images to [0, 1]. Cast to float32 first: dividing the uint8 arrays
# by 255.0 directly yields float64, doubling memory for no benefit since Keras
# computes in float32 by default.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
# Build model with dropout
model = models.Sequential([
    layers.Input(shape=(32, 32, 3)),
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dropout(0.5),
    # One output unit per kept class.
    layers.Dense(len(KEPT_CLASSES), activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# Early stopping callback: tracks val_loss and, with restore_best_weights=True,
# puts back the weights of the lowest-val_loss epoch.
# NOTE(review): some older TF releases restore only when training actually
# stopped early (not when all epochs run) — confirm for the TF version in use.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
# Train with validation split (Keras holds out the last 20% of the training arrays).
history = model.fit(x_train, y_train, epochs=50, batch_size=64, validation_split=0.2, callbacks=[early_stop], verbose=0)
# Evaluate on test set
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
# Predict classes (quiet, matching the other calls)
y_pred_probs = model.predict(x_test, verbose=0)
y_pred = np.argmax(y_pred_probs, axis=1)
# Classification report (dict form kept for programmatic use) and confusion matrix
class_report = classification_report(y_test, y_pred, output_dict=True)
conf_matrix = confusion_matrix(y_test, y_pred)
# The "best" epoch is the one EarlyStopping selects by its monitored quantity
# (val_loss), not the epoch with maximum val_accuracy. Report that epoch's
# validation accuracy so the number describes the restored weights.
# np.argmin returns the first minimum, matching EarlyStopping's strict-improvement rule.
best_epoch = int(np.argmin(history.history['val_loss']))
# Prepare output metrics string
metrics_str = (
    f"Training accuracy (last epoch): {history.history['accuracy'][-1]*100:.2f}%, "
    f"Validation accuracy (best): {history.history['val_accuracy'][best_epoch]*100:.2f}%, "
    f"Test accuracy: {test_acc*100:.2f}%, Test loss: {test_loss:.4f}"
)
# Print classification report and confusion matrix for learner
print(metrics_str)
print("Classification Report:")
print(classification_report(y_test, y_pred))
print("Confusion Matrix:")
print(conf_matrix)