Computer Vision · How-To · Beginner · 4 min read

How to Use Transfer Learning for Images in Computer Vision

Use transfer learning by loading a pretrained image model like ResNet50, replacing its final layer to match your task, and training only the new layers or fine-tuning the whole model on your dataset. This approach leverages learned features from large datasets to improve accuracy and reduce training time.
📝

Syntax

Transfer learning typically involves these steps:

  • Load a pretrained model with weights trained on a large dataset (e.g., ImageNet).
  • Freeze the base layers to preserve their learned features (or leave some unfrozen for fine-tuning).
  • Replace the final classification layer to fit your number of classes.
  • Compile and train the model on your new dataset.
python
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

num_classes = 10  # Define number of classes

# Load pretrained base model without top layer
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze base model layers
base_model.trainable = False

# Add new classification head
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)

# Create new model
model = Model(inputs=base_model.input, outputs=predictions)

# Compile model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
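Once the new head has been trained with the base frozen, the base can be partially unfrozen for fine-tuning, as mentioned in the steps above. A minimal sketch of that second phase follows; the "last 20 layers" cutoff and the learning rate are illustrative choices, not fixed rules:

```python
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

num_classes = 10

# Same setup as above: pretrained base plus a new classification head
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(1024, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)

# ... phase 1: train the head here while the base stays frozen ...

# Phase 2: unfreeze only the last few layers of the base for fine-tuning
base_model.trainable = True
for layer in base_model.layers[:-20]:
    layer.trainable = False

# Recompile with a much lower learning rate so the pretrained
# weights are only nudged, not overwritten
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
```

A lower learning rate is the key detail here: fine-tuning with the original rate can quickly destroy the pretrained features you set out to reuse.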
💻

Example

This example shows how to use transfer learning with ResNet50 to classify images into 5 classes. It freezes the base model, adds new layers, and trains only the new layers.

python
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

num_classes = 5
batch_size = 16

# Load pretrained ResNet50 without top layer
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False

# Add new layers
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(512, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)

# Compile model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Prepare dummy data for demonstration (random images and labels)
import numpy as np  # local import for the synthetic data below
x_train = np.random.uniform(0, 255, size=(batch_size * 10, 224, 224, 3)).astype('float32')
y_train = tf.keras.utils.to_categorical(
    np.random.randint(0, num_classes, size=batch_size * 10), num_classes)

train_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow(x_train, y_train, batch_size=batch_size)

# Train model for 1 epoch
history = model.fit(train_generator, steps_per_epoch=10, epochs=1)

# Show training accuracy
print(f"Training accuracy after 1 epoch: {history.history['accuracy'][0]:.4f}")
Output
Epoch 1/1
10/10 [==============================] - 12s 1s/step - loss: 1.6092 - accuracy: 0.1875
Training accuracy after 1 epoch: 0.1875
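After training, the model can be used for predictions; raw images must receive the same preprocessing the pretrained backbone expects. The sketch below rebuilds the same architecture so it runs on its own, and a random array stands in for a real photo loaded from disk:

```python
import numpy as np
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

num_classes = 5

# Rebuild the same architecture as in the example above
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(512, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)

# A random array stands in for a real image loaded from disk
raw_image = np.random.uniform(0, 255, size=(1, 224, 224, 3)).astype('float32')

# Apply the ResNet50-specific preprocessing before predicting
processed = preprocess_input(raw_image)
probs = model.predict(processed)  # shape: (1, num_classes)
predicted_class = int(np.argmax(probs, axis=1)[0])
print(f"Predicted class index: {predicted_class}")
```

In practice you would call `predict` on the model you actually trained; rebuilding it here just keeps the sketch self-contained.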
⚠️

Common Pitfalls

  • Not freezing base layers: Training all layers from scratch can cause overfitting and slow training.
  • Mismatched input size: Pretrained models expect specific input sizes (e.g., 224x224 for ResNet50).
  • Wrong number of output classes: The final layer must match your dataset's classes.
  • Not preprocessing inputs: Inputs must be normalized or preprocessed as the pretrained model expects.
python
import numpy as np
from tensorflow.keras.applications.resnet50 import preprocess_input

# Dummy batch standing in for real images (pixel values in 0-255)
raw_images = np.random.uniform(0, 255, size=(4, 224, 224, 3)).astype('float32')

# Wrong way: feeding raw pixel values straight to the model
# predictions = model.predict(raw_images)

# Right way: apply the model-specific preprocessing first
processed_images = preprocess_input(raw_images)
# predictions = model.predict(processed_images)
📊

Quick Reference

Key tips for transfer learning in image tasks:

  • Use pretrained models like ResNet, VGG, or MobileNet.
  • Freeze base layers initially, then optionally fine-tune later.
  • Replace the top layer to match your classes.
  • Preprocess inputs with model-specific functions.
  • Use data augmentation to improve generalization.
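The data-augmentation tip above can be sketched with Keras preprocessing layers; the specific transforms and their parameters are illustrative choices:

```python
import tensorflow as tf

# A small augmentation pipeline built from Keras preprocessing layers
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
])

# Augmentation is applied on the fly during training; with
# training=False (the inference default) these layers do nothing
images = tf.random.uniform([8, 224, 224, 3])
augmented = data_augmentation(images, training=True)
print(augmented.shape)  # (8, 224, 224, 3)
```

These layers can also be placed directly between the model's input and the pretrained base, so augmentation happens inside the model itself during training.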
✅

Key Takeaways

Load a pretrained image model and replace its final layer to fit your task.
Freeze base layers to keep learned features and train only new layers initially.
Preprocess input images as required by the pretrained model.
Fine-tune base layers later if more accuracy is needed.
Use data augmentation to help your model generalize better.