import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Small convolutional encoder/decoder used for image inpainting.
# Takes a 64x64 RGB image and produces a 64x64 RGB image.
input_img = layers.Input(shape=(64, 64, 3))

# Encoder: each stage is conv -> 2x2 max-pool -> dropout, so the two stages
# reduce the spatial size 64 -> 32 -> 16 while widening to 64 then 128 filters.
x = input_img
for encoder_filters in (64, 128):
    x = layers.Conv2D(
        filters=encoder_filters,
        kernel_size=(3, 3),
        activation='relu',
        padding='same',
    )(x)
    x = layers.MaxPooling2D(pool_size=(2, 2), padding='same')(x)
    x = layers.Dropout(rate=0.3)(x)

# Decoder: each stage is conv -> 2x2 upsampling, mirroring the encoder and
# restoring the spatial size 16 -> 32 -> 64.
for decoder_filters in (128, 64):
    x = layers.Conv2D(
        filters=decoder_filters,
        kernel_size=(3, 3),
        activation='relu',
        padding='same',
    )(x)
    x = layers.UpSampling2D(size=(2, 2))(x)

# Final 3-channel projection; the sigmoid keeps outputs in [0, 1].
output_img = layers.Conv2D(
    filters=3,
    kernel_size=(3, 3),
    activation='sigmoid',
    padding='same',
)(x)

model = models.Model(inputs=input_img, outputs=output_img)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='mse',
)
# Data augmentation to increase training data variety.
#
# Inpainting is an image-to-image task, so every random geometric transform
# applied to an input image must be applied identically to its target image.
# `ImageDataGenerator.flow(x, y)` only augments `x` and passes `y` through
# untouched (it treats `y` as labels), which would misalign each augmented
# input with its target. Instead, two generators with the same settings are
# driven with the same seed so they draw the same shuffles and transforms.
import numpy as np

AUGMENTATION_SEED = 42
augmentation_args = dict(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
)
train_datagen = ImageDataGenerator(**augmentation_args)
target_datagen = ImageDataGenerator(**augmentation_args)


def _paired_batches(input_flow, target_flow):
    """Yield ``(input_batch, target_batch)`` tuples from two synced flows.

    Both flows must have been created with the same seed, batch size and
    shuffle setting so that corresponding batches hold the same samples
    with the same random transform. The generator never terminates, so the
    consumer has to bound each epoch (e.g. with ``steps_per_epoch``).
    """
    while True:
        yield next(input_flow), next(target_flow)


# X_train holds the model inputs and Y_train the reconstruction targets.
# NOTE(review): for inpainting the inputs should be the masked images and the
# targets the complete originals -- confirm the real data follows this order.
# For demonstration, random placeholders are used.
X_train = np.random.rand(100, 64, 64, 3).astype('float32')
Y_train = np.random.rand(100, 64, 64, 3).astype('float32')
batch_size = 16

# Early stopping callback: stop once val_loss has not improved for 5 epochs
# and roll back to the best weights seen.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

# Split data for training and validation (first 80% / last 20%).
split_idx = int(0.8 * len(X_train))
X_tr, X_val = X_train[:split_idx], X_train[split_idx:]
Y_tr, Y_val = Y_train[:split_idx], Y_train[split_idx:]

# Data generators.
# Training: inputs and targets flow through separate but identically seeded
# generators so the augmentation stays aligned between the two.
input_flow = train_datagen.flow(
    X_tr, batch_size=batch_size, shuffle=True, seed=AUGMENTATION_SEED
)
target_flow = target_datagen.flow(
    Y_tr, batch_size=batch_size, shuffle=True, seed=AUGMENTATION_SEED
)
train_generator = _paired_batches(input_flow, target_flow)
# One epoch is one pass over the training split.
steps_per_epoch = len(input_flow)

# Validation: no augmentation, so the regular (x, y) flow is already aligned;
# shuffling is unnecessary for evaluation.
val_datagen = ImageDataGenerator()
val_generator = val_datagen.flow(
    X_val, Y_val, batch_size=batch_size, shuffle=False
)

# Fit model with data augmentation.
# `steps_per_epoch` is required because the paired generator is endless.
model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=50,
    validation_data=val_generator,
    callbacks=[early_stop],
)