TensorFlow · ~5 mins

Transfer learning for small datasets in TensorFlow

Introduction

Transfer learning lets you reuse a model trained on a large dataset to solve a new problem with a small one. It saves training time and often improves accuracy.

Transfer learning is useful when, for example:

You have only a few images to train an object-recognition model.
You want to build a speech recognition system but lack large amounts of voice data.
You need to classify medical images but have limited labeled examples.
You want to quickly create a model for a new task without starting from scratch.
Syntax
TensorFlow
base_model = tf.keras.applications.MobileNetV2(input_shape=(224,224,3), include_top=False, weights='imagenet')
base_model.trainable = False

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

Set include_top=False to remove the original classification layer.

Freeze the base model weights by setting trainable = False to avoid changing them during training.
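Both effects are easy to verify: with include_top=False the model outputs a feature map rather than class scores, and freezing drops the trainable weight count to zero. A minimal check (weights=None is used here only so the sketch runs without downloading the ImageNet weights; in practice pass weights='imagenet'):

```python
import tensorflow as tf

# Load MobileNetV2 without its classification head.
# weights=None avoids the ImageNet download for this quick check;
# use weights='imagenet' for real transfer learning.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights=None)

# With include_top=False, the output is a feature map, not class scores.
print(base_model.output_shape)  # (None, 7, 7, 1280)

print(len(base_model.trainable_weights))  # many trainable tensors

base_model.trainable = False  # freeze every layer in the base model

print(len(base_model.trainable_weights))  # 0 -- nothing left to update
```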

Examples
Using ResNet50 as the base model and freezing its layers.
TensorFlow
base_model = tf.keras.applications.ResNet50(include_top=False, weights='imagenet')
base_model.trainable = False
Adding a new output layer for 10 classes after the base model.
TensorFlow
model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation='softmax')
])
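A multi-class head like this needs a matching loss when the model is compiled. A sketch putting the two examples together (weights=None keeps it offline-friendly; use weights='imagenet' in practice, and sparse_categorical_crossentropy assumes integer class labels 0–9):

```python
import tensorflow as tf

# Frozen ResNet50 base, as in the example above
# (weights=None here only to avoid the ImageNet download).
base_model = tf.keras.applications.ResNet50(
    input_shape=(224, 224, 3), include_top=False, weights=None)
base_model.trainable = False

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation='softmax')
])

# sparse_categorical_crossentropy matches integer labels
# against the 10-way softmax output.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

print(model.output_shape)  # (None, 10)
```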
Sample Model

This example shows how to use MobileNetV2 pretrained on ImageNet as a base model. We freeze it and add a new output layer for binary classification. We train on a tiny dataset of 10 random images and print predictions.

TensorFlow
import tensorflow as tf
from tensorflow.keras import layers

# Load base model with pretrained weights, exclude top layer
base_model = tf.keras.applications.MobileNetV2(input_shape=(224,224,3), include_top=False, weights='imagenet')
base_model.trainable = False  # Freeze base model

# Build new model on top
model = tf.keras.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(1, activation='sigmoid')
])

# Compile model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Create dummy small dataset (10 images, binary labels)
x_train = tf.random.uniform((10,224,224,3))
y_train = tf.constant([0,1,0,1,0,1,0,1,0,1], dtype=tf.float32)

# Train model on small dataset
history = model.fit(x_train, y_train, epochs=3, verbose=2)

# Make predictions
predictions = model.predict(x_train)
print('Predictions:', predictions.flatten())
Important Notes

Freezing the base model prevents losing the knowledge it already learned.

You can later unfreeze some layers to fine-tune the model if you get more data.
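One common fine-tuning pattern is to unfreeze only the top layers of the base model and recompile with a much lower learning rate. A sketch, assuming the MobileNetV2 setup from the sample above (weights=None keeps it offline-friendly; the cutoff of 20 layers is an arbitrary illustrative choice):

```python
import tensorflow as tf

# Rebuild the sample model (weights=None only to avoid the download;
# use weights='imagenet' in practice).
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights=None)
model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

# Unfreeze the base model, then re-freeze all but its last 20 layers.
base_model.trainable = True
for layer in base_model.layers[:-20]:
    layer.trainable = False

# Recompile with a small learning rate so the pretrained weights
# are only nudged, not overwritten.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss='binary_crossentropy',
              metrics=['accuracy'])
```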

Use data augmentation to help small datasets generalize better.
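Augmentation can be added with Keras preprocessing layers, which randomly transform images only when called in training mode. A minimal sketch:

```python
import tensorflow as tf

# Small augmentation pipeline: random flips, rotations, and zooms.
# Each transform is applied only when training=True.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
])

images = tf.random.uniform((4, 224, 224, 3))
augmented = data_augmentation(images, training=True)
print(augmented.shape)  # (4, 224, 224, 3)
```

In practice this Sequential can be placed at the front of the transfer-learning model, before the frozen base, so every training batch sees slightly different images.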

Summary

Transfer learning uses a pretrained model to help with small datasets.

Freeze the base model layers to keep learned features.

Add new layers on top for your specific task and train only them.