import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Define U-Net model with dropout
def unet_model(input_size=(128, 128, 1)):
    """Build a small 2-level U-Net for binary segmentation.

    Args:
        input_size: Input shape ``(H, W, C)``. H and W must be divisible
            by 4, since the encoder downsamples twice by a factor of 2
            and the decoder must upsample back to the same size for the
            skip-connection concatenations to align.

    Returns:
        An uncompiled ``tf.keras`` Model mapping images of shape
        ``input_size`` to a single-channel sigmoid mask with the same
        spatial dimensions.
    """
    inputs = layers.Input(input_size)

    # --- Encoder ---
    c1 = layers.Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
    c1 = layers.Dropout(0.1)(c1)  # light dropout in shallow layers
    c1 = layers.Conv2D(16, (3, 3), activation='relu', padding='same')(c1)
    p1 = layers.MaxPooling2D((2, 2))(c1)

    c2 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(p1)
    c2 = layers.Dropout(0.1)(c2)
    c2 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(c2)
    p2 = layers.MaxPooling2D((2, 2))(c2)

    # --- Bottleneck (heavier dropout at the deepest, widest layer) ---
    c3 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(p2)
    c3 = layers.Dropout(0.2)(c3)
    c3 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c3)

    # --- Decoder (transpose-conv upsampling + skip connections) ---
    u4 = layers.Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(c3)
    u4 = layers.concatenate([u4, c2])  # skip connection from encoder level 2
    c4 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(u4)
    c4 = layers.Dropout(0.1)(c4)
    c4 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(c4)

    u5 = layers.Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(c4)
    u5 = layers.concatenate([u5, c1])  # skip connection from encoder level 1
    c5 = layers.Conv2D(16, (3, 3), activation='relu', padding='same')(u5)
    c5 = layers.Dropout(0.1)(c5)
    c5 = layers.Conv2D(16, (3, 3), activation='relu', padding='same')(c5)

    # 1x1 conv + sigmoid -> per-pixel foreground probability
    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(c5)

    model = models.Model(inputs=[inputs], outputs=[outputs])
    return model
# Create data augmentation generator
# --- Data augmentation ---
# NOTE(review): for segmentation, spatial transforms must be applied to
# images AND masks identically. The original code passed y_train as the
# "labels" argument of a single ImageDataGenerator.flow(), which augments
# the images but returns the masks untransformed, silently misaligning
# every augmented image/mask pair. Fix: two generators sharing the same
# augmentation config and the same seed (the seed drives both shuffling
# and the per-batch random transforms, keeping the streams in lockstep).
augment_args = dict(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
)
image_datagen = ImageDataGenerator(**augment_args)
mask_datagen = ImageDataGenerator(**augment_args)

# Assume X_train, y_train, X_val, y_val are preloaded numpy arrays of images and masks
model = unet_model()
# NOTE(review): pixel accuracy is a weak metric for imbalanced masks;
# consider adding IoU/Dice. Left unchanged to preserve the training setup.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Fit model with paired image/mask augmentation
batch_size = 16
seed = 42  # identical seed => identical transforms and shuffle order in both streams
image_generator = image_datagen.flow(X_train, batch_size=batch_size, seed=seed)
mask_generator = mask_datagen.flow(y_train, batch_size=batch_size, seed=seed)
train_generator = zip(image_generator, mask_generator)  # yields (images, masks) batches

history = model.fit(
    train_generator,
    steps_per_epoch=len(X_train) // batch_size,
    epochs=50,
    validation_data=(X_val, y_val),
)