Transfer learning for small datasets in TensorFlow
Transfer learning lets you reuse a model trained on a large dataset to solve a new problem with only a small dataset. It saves training time and usually improves accuracy.
import tensorflow as tf

base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights='imagenet')
base_model.trainable = False

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])
Set include_top=False to remove the original classification layer.
Freeze the base model weights by setting trainable = False to avoid changing them during training.
base_model = tf.keras.applications.ResNet50(include_top=False, weights='imagenet')
base_model.trainable = False
model = tf.keras.Sequential([
base_model,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(10, activation='softmax')
])

This example shows how to use MobileNetV2 pretrained on ImageNet as a base model: we freeze it, add a new output layer for binary classification, train on a tiny dataset of 10 random images, and print the predictions.
import tensorflow as tf
from tensorflow.keras import layers

# Load base model with pretrained weights, excluding the top classification layer
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights='imagenet')
base_model.trainable = False  # Freeze base model

# Build new model on top
model = tf.keras.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(1, activation='sigmoid')
])

# Compile model
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Create dummy small dataset (10 images, binary labels)
x_train = tf.random.uniform((10, 224, 224, 3))
y_train = tf.constant([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], dtype=tf.float32)

# Train model on small dataset
history = model.fit(x_train, y_train, epochs=3, verbose=2)

# Make predictions
predictions = model.predict(x_train)
print('Predictions:', predictions.flatten())
Freezing the base model prevents losing the knowledge it already learned.
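As a quick sanity check, you can confirm that freezing really removes the base model's weights from training. The sketch below assumes the MobileNetV2 setup from earlier; weights=None is used here only so the example runs without downloading ImageNet weights (in practice you would keep weights='imagenet').

```python
import tensorflow as tf

# weights=None only to avoid the ImageNet download in this sketch
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights=None)
base_model.trainable = False  # freeze everything in the base

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

# Only the new Dense layer's kernel and bias remain trainable
print(len(model.trainable_variables))  # 2
```

If the count were larger, gradient updates would reach into the pretrained base and overwrite its learned features.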
You can later unfreeze some layers to fine-tune the model if you get more data.
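A minimal fine-tuning sketch, assuming the same MobileNetV2 base: unfreeze only the last few layers and recompile with a much lower learning rate so training does not destroy the pretrained features. The cutoff of 20 layers and the learning rate of 1e-5 are illustrative choices, not fixed rules (weights=None again just avoids the download in this sketch).

```python
import tensorflow as tf

base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights=None)

# Unfreeze the base, then re-freeze all but the last 20 layers
base_model.trainable = True
for layer in base_model.layers[:-20]:
    layer.trainable = False

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

# Recompile with a low learning rate for gentle fine-tuning
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss='binary_crossentropy',
              metrics=['accuracy'])
```

Note that you must recompile after changing the trainable flags, or the change has no effect on training.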
Use data augmentation to help small datasets generalize better.
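One common way to do this (a sketch using the Keras preprocessing layers available in TF 2.6+) is to randomly perturb each training image, so a small dataset effectively covers more variation. The specific transforms and their strengths here are illustrative.

```python
import tensorflow as tf

data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
])

# Augmentation is applied in training mode and is a no-op at inference
images = tf.random.uniform((4, 224, 224, 3))
augmented = data_augmentation(images, training=True)
print(augmented.shape)  # (4, 224, 224, 3)
```

These layers can be placed directly in front of the base model inside the Sequential model, so augmentation happens on the GPU as part of the forward pass.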
Transfer learning uses a pretrained model to help with small datasets.
Freeze the base model layers to keep learned features.
Add new layers on top for your specific task and train only them.