import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping
# Load MNIST. Only the images are kept: the autoencoder is trained to
# reconstruct its own input, so the digit labels are discarded.
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()

# Convert each split to float32, divide by 255 so 8-bit pixel values land in
# [0, 1], and append a trailing single-channel axis so every split has shape
# (num_images, 28, 28, 1) as the Conv2D input layer expects.
x_train, x_test = (
    (images.astype('float32') / 255.).reshape((-1, 28, 28, 1))
    for images in (x_train, x_test)
)
# Define autoencoder model with dropout and smaller hidden layers.
# Two 2x2 pooling stages shrink the 28x28 input to 7x7; two 2x2 upsampling
# stages restore it to 28x28 for the single-channel reconstruction.
input_img = layers.Input(shape=(28, 28, 1))

# Encoder: each stage is conv (ReLU) -> 2x2 max-pool -> 20% dropout,
# first with 16 filters, then with 8.
x = input_img
for filters in (16, 8):
    x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2), padding='same')(x)
    x = layers.Dropout(0.2)(x)

# Bottleneck: the 8-filter 7x7 compressed representation.
encoded = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(x)

# Decoder: mirror of the encoder, each stage is conv (ReLU) -> 2x2
# upsample -> 20% dropout, first with 8 filters, then with 16.
x = encoded
for filters in (8, 16):
    x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
    x = layers.UpSampling2D((2, 2))(x)
    x = layers.Dropout(0.2)(x)

# Sigmoid keeps reconstructed pixels in [0, 1], matching the scaled inputs.
decoded = layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = models.Model(input_img, decoded)
# Compile with a low learning rate.
#
# The reconstruction targets are continuous pixel values in [0, 1], so report
# mean absolute error instead of 'accuracy'. For a single-unit output trained
# with binary_crossentropy, Keras resolves 'accuracy' to binary_accuracy. That
# metric thresholds the prediction to 0/1 and tests it for exact equality with
# the target, so any pixel strictly between 0 and 1 can never count as correct.
# The resulting percentage does not measure reconstruction quality.
autoencoder.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss='binary_crossentropy',
    metrics=['mae'],
)

# Stop once validation loss has not improved for 5 consecutive epochs and
# roll the model back to the weights of its best epoch.
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Train the model to reproduce its input.
#
# Early stopping picks the best epoch using a held-out slice of the training
# data: Keras takes the last 10% of samples, before shuffling. The test set
# therefore plays no part in model selection, and the test numbers reported
# below are not biased by it.
history = autoencoder.fit(
    x_train, x_train,
    epochs=50,
    batch_size=128,
    shuffle=True,
    validation_split=0.1,
    callbacks=[early_stop],
)

# Evaluate final metrics. evaluate() returns [loss, *metrics] in the order
# given to compile(), i.e. (binary cross-entropy, MAE) here.
# "Training" covers all of x_train, including the 10% validation slice.
train_loss, train_mae = autoencoder.evaluate(x_train, x_train, verbose=0)
test_loss, test_mae = autoencoder.evaluate(x_test, x_test, verbose=0)
print(f"Training loss: {train_loss:.3f}, Training MAE: {train_mae:.4f}")
print(f"Test loss: {test_loss:.3f}, Test MAE: {test_mae:.4f}")