import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Define a small convolutional image classifier.
# NOTE(review): the original comment called this "YOLO-like", but the network
# below is a plain classification CNN: a single softmax head with no
# bounding-box outputs. It is not an object detector.
inputs = layers.Input(shape=(224, 224, 3))

# Feature extractor: each stage is a 3x3 same-padded ReLU conv followed by 2x
# max pooling, and optionally a dropout layer.
# Each entry is (conv filters, dropout rate after pooling or None for no dropout).
_CONV_STAGES = ((16, None), (32, 0.3), (64, 0.3))

x = inputs
for _filters, _dropout_rate in _CONV_STAGES:
    x = layers.Conv2D(_filters, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D(2)(x)
    if _dropout_rate is not None:
        x = layers.Dropout(_dropout_rate)(x)

# Classification head: flatten -> dense(128) -> dropout -> softmax.
x = layers.Flatten()(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.4)(x)
# NOTE(review): the class count (10) is hard-coded. It must equal the number of
# classes the training generator discovers in the data directory, otherwise
# categorical_crossentropy will see mismatched label/prediction shapes.
outputs = layers.Dense(10, activation='softmax')(x)
model = models.Model(inputs, outputs)

# Adam with a reduced learning rate of 5e-4; labels are expected one-hot
# (categorical cross-entropy).
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
# Input pipeline.
# Both generators read 'data/train' with validation_split=0.2 and pick the
# 'training' / 'validation' subset respectively. Presumably Keras splits the
# files deterministically, so the two subsets are disjoint — confirm for the
# installed Keras version.
# Only the training generator applies random augmentation. Validation images
# are only rescaled (1/255; assumes 8-bit pixel values).
_flow_options = {
    'target_size': (224, 224),
    'batch_size': 32,
    'class_mode': 'categorical',
}

train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    validation_split=0.2,
)
train_generator = train_datagen.flow_from_directory(
    'data/train', subset='training', **_flow_options
)

val_datagen = ImageDataGenerator(rescale=1. / 255, validation_split=0.2)
validation_generator = val_datagen.flow_from_directory(
    'data/train', subset='validation', **_flow_options
)
# Train for up to 50 epochs. Training stops once val_loss has not improved
# for 5 consecutive epochs. restore_best_weights=True asks Keras to reload the
# weights from the lowest-val_loss epoch. NOTE(review): some older Keras
# versions only restore when early stopping actually triggers — confirm for
# the installed version.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)
training_callbacks = [early_stop]

history = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=50,
    callbacks=training_callbacks,
)
# After training, report metrics.
# history.history maps each metric name to a list with one value per epoch.
# The early-stopping callback uses restore_best_weights=True, so the model in
# memory may hold the weights of the lowest-val_loss epoch rather than the
# last one. The last-epoch numbers alone can therefore misdescribe the
# returned model, so report the best epoch as well.
per_epoch = history.history

# Last-epoch metrics (kept for backward compatibility with existing output).
final_train_acc = per_epoch['accuracy'][-1] * 100
final_val_acc = per_epoch['val_accuracy'][-1] * 100
final_val_loss = per_epoch['val_loss'][-1]

# Epoch with the lowest val_loss. min() keeps the FIRST minimum, which matches
# an early-stopping rule that only accepts strict improvements.
_val_losses = per_epoch['val_loss']
best_epoch = min(range(len(_val_losses)), key=_val_losses.__getitem__)
best_train_acc = per_epoch['accuracy'][best_epoch] * 100
best_val_acc = per_epoch['val_accuracy'][best_epoch] * 100
best_val_loss = _val_losses[best_epoch]

print(f'Training accuracy: {final_train_acc:.2f}%')
print(f'Validation accuracy: {final_val_acc:.2f}%')
print(f'Validation loss: {final_val_loss:.4f}')
print(f'Best epoch (lowest val_loss): {best_epoch + 1} of {len(_val_losses)}')
print(f'Best-epoch training accuracy: {best_train_acc:.2f}%')
print(f'Best-epoch validation accuracy: {best_val_acc:.2f}%')
print(f'Best-epoch validation loss: {best_val_loss:.4f}')