import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Backbone: MobileNetV2 pretrained on ImageNet, with its ImageNet
# classification head stripped (include_top=False) so we can attach our own.
base_model = MobileNetV2(input_shape=(128, 128, 3), include_top=False, weights='imagenet')
# Freeze the backbone: phase one trains only the new head.
base_model.trainable = False

# Classification head: pool the backbone's spatial feature map to a vector,
# regularize with dropout, then project to the 3 target classes.
classifier_head = [
    layers.GlobalAveragePooling2D(),
    layers.Dropout(0.3),
    layers.Dense(3, activation='softmax'),
]
model = models.Sequential([base_model, *classifier_head])

# Integer labels -> sparse categorical cross-entropy.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# Training pipeline with augmentation.
# Fixes vs. the original:
#  * MobileNetV2's ImageNet weights were trained on inputs scaled to [-1, 1];
#    use the matching preprocess_input instead of rescale=1./255 (which maps
#    pixels to [0, 1] and mismatches the pretrained weights).
#  * Augmentation must apply only to the training subset — the original fed
#    the validation subset through the same augmenting generator, which
#    randomly distorts validation images and skews validation metrics.
train_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    validation_split=0.2
)
# Validation pipeline: same preprocessing and the same split fraction
# (the split is deterministic on file order, so the subsets don't overlap),
# but no augmentation.
val_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input,
    validation_split=0.2
)
train_generator = train_datagen.flow_from_directory(
    'data/train',
    target_size=(128, 128),
    batch_size=32,
    class_mode='sparse',
    subset='training'
)
validation_generator = val_datagen.flow_from_directory(
    'data/train',
    target_size=(128, 128),
    batch_size=32,
    class_mode='sparse',
    subset='validation',
    shuffle=False  # deterministic order makes epoch-to-epoch metrics comparable
)
# Phase 1: train only the freshly-added head (backbone is frozen above).
history = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=15,
)
# Phase 2 (fine-tuning): unfreeze the top of the backbone and train
# end-to-end with a much lower learning rate.
base_model.trainable = True
# Keep the earliest layers frozen — low-level features transfer well, and
# freezing them limits overfitting on a small dataset.
fine_tune_at = 100
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False
# NOTE(review): unfreezing also makes the backbone's BatchNormalization
# layers trainable; the official TF fine-tuning guide recommends keeping
# the backbone in inference mode (training=False) instead — confirm intent.
# Re-compile so the new trainable state takes effect; a tiny learning rate
# avoids destroying the pretrained weights.
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# Resume the epoch counter from where phase 1 stopped instead of restarting
# at 0 (the original's bare epochs=10 relabels epochs 1-10, which garbles
# logs/TensorBoard and misleads any epoch-keyed schedules or callbacks).
initial_epochs = len(history.epoch)
fine_tune_history = model.fit(
    train_generator,
    epochs=initial_epochs + 10,   # 10 additional fine-tuning epochs
    initial_epoch=initial_epochs,
    validation_data=validation_generator
)