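Feature extraction is a transfer-learning strategy. You reuse a model that was pre-trained on a large dataset as a fixed feature extractor, and you train only a new output layer for your own task. This saves training time. It usually works well when your dataset is small, because the model reuses features it has already learned instead of learning them from scratch.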
Feature extraction strategy in PyTorch
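import torch
import torchvision.models as models

num_classes = 10  # set this to the number of classes in your task

# Load a pre-trained model (torchvision >= 0.13; older versions use pretrained=True)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze all layers so their weights are not updated during training
for param in model.parameters():
    param.requires_grad = False

# Replace the final layer to match your task (a newly created layer requires gradients by default)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

# Now only the final layer will be trained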
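Freezing a layer means setting requires_grad = False on its parameters. No gradients are computed for those weights, so the optimizer never updates them.
Replacing the final layer adapts the model to your specific problem. The new layer has one output per class. Because it is created after the freezing loop, its parameters still require gradients, and they are the only ones that get trained.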
For example, adapting ResNet18 to a 10-class problem:
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in model.parameters():
    param.requires_grad = False
model.fc = torch.nn.Linear(model.fc.in_features, 10)
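With VGG16, the convolutional layers live in model.features and the last classification layer is model.classifier[6]. This example freezes only the convolutional part. The fully connected layers in model.classifier therefore stay trainable, together with the new 5-class output layer:
model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

# Freeze only the convolutional feature extractor
for param in model.features.parameters():
    param.requires_grad = False

# Replace the last classifier layer (4096 -> 1000) with a 5-class layer
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 5)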
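The complete example below loads a pre-trained ResNet18, freezes all layers, replaces the last layer with a 3-class layer, and runs a batch of random dummy data through it. It prints the output shape, torch.Size([2, 3]). It also prints the number of trainable parameter tensors. This should be 2, the weight and the bias of the single new final layer.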
import torch
import torch.nn as nn
import torchvision.models as models

# Number of classes for the new task
num_classes = 3

# Load pre-trained ResNet18 (torchvision >= 0.13)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Create dummy input (batch size 2, 3 color channels, 224x224 image)
dummy_input = torch.randn(2, 3, 224, 224)

# Get output predictions
output = model(dummy_input)

# Print the output shape and how many parameter tensors still require gradients
print(f"Output shape: {output.shape}")
trainable_params = [p for p in model.parameters() if p.requires_grad]
print(f"Number of trainable parameter tensors: {len(trainable_params)}")
Freezing layers preserves the features learned during pre-training. It also reduces training time, because no gradients are computed for the frozen weights.
Only the replaced final layer's parameters require gradients and will update during training.
Before training, pass a dummy input of the expected shape (batch size, 3 channels, 224 x 224) through the model to check that the output has one value per class.
Feature extraction uses pre-trained models to get useful data features.
Freeze layers to keep their knowledge and train only new parts.
Replace the final layer to fit your specific task.