
Model optimization (distillation, quantization) in NLP - ML Experiment: Train & Evaluate

Experiment - Model optimization (distillation, quantization)
Problem: You have a large NLP model that performs well on text classification but is too slow and large for deployment on mobile devices.
Current Metrics: Training accuracy 95%, validation accuracy 93%, model size 500 MB, inference time 500 ms per sample
Issue: The model is too large and slow, making it unsuitable for mobile deployment despite good accuracy.
Your Task
Reduce the model size and inference time by applying model distillation and quantization while keeping validation accuracy above 90%.
You must keep the original training data and model architecture for the teacher model unchanged.
You can only modify the student model size and apply quantization after training.
Validation accuracy must not drop below 90%.
Solution
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, DistilBertForSequenceClassification

# Assume train_loader and val_loader (DataLoaders over tokenized text batches) are defined elsewhere
# teacher_model: a large BERT already fine-tuned on the task
# student_model: a smaller model (here DistilBERT) trained via distillation

class DistillationLoss(nn.Module):
    def __init__(self, temperature=2.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
        self.ce = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        soft_targets = nn.functional.log_softmax(student_logits / self.temperature, dim=1)
        soft_labels = nn.functional.softmax(teacher_logits / self.temperature, dim=1)
        distill_loss = self.kl_div(soft_targets, soft_labels) * (self.temperature ** 2)
        student_loss = self.ce(student_logits, labels)
        return self.alpha * student_loss + (1 - self.alpha) * distill_loss

# Load teacher model (in practice, load your fine-tuned checkpoint;
# from_pretrained alone gives an untrained classification head)
teacher_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
teacher_model.eval()  # teacher stays frozen during distillation

# Load student model (smaller)
student_model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)

optimizer = optim.Adam(student_model.parameters(), lr=5e-5)
criterion = DistillationLoss(temperature=2.0, alpha=0.5)

# Training loop for distillation (assumes models and batches are on the same device)
for epoch in range(3):
    student_model.train()
    for batch in train_loader:
        inputs = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        with torch.no_grad():
            teacher_logits = teacher_model(inputs, attention_mask=attention_mask).logits

        student_logits = student_model(inputs, attention_mask=attention_mask).logits

        loss = criterion(student_logits, teacher_logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate student model accuracy on validation set
student_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in val_loader:
        inputs = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        outputs = student_model(inputs, attention_mask=attention_mask).logits
        preds = torch.argmax(outputs, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
val_accuracy = correct / total * 100

# Apply post-training dynamic quantization: linear-layer weights are stored
# as int8 and activations are quantized on the fly at inference time
quantized_model = torch.quantization.quantize_dynamic(
    student_model, {nn.Linear}, dtype=torch.qint8
)

# Measure model size and inference time
import time
import os

def model_size(model):
    torch.save(model.state_dict(), 'temp.pth')
    size = os.path.getsize('temp.pth') / 1e6  # MB
    os.remove('temp.pth')
    return size

size_before = model_size(student_model)
size_after = model_size(quantized_model)

# Inference time measurement
sample_batch = next(iter(val_loader))
input_sample = sample_batch['input_ids']
attention_sample = sample_batch['attention_mask']

with torch.no_grad():  # disable autograd so timing reflects pure inference
    start = time.time()
    _ = student_model(input_sample, attention_mask=attention_sample)
    end = time.time()
inference_time_before = (end - start) / input_sample.size(0) * 1000  # ms per sample

with torch.no_grad():
    start = time.time()
    _ = quantized_model(input_sample, attention_mask=attention_sample)
    end = time.time()
inference_time_after = (end - start) / input_sample.size(0) * 1000  # ms per sample

print(f"Validation accuracy after distillation: {val_accuracy:.2f}%")
print(f"Model size before quantization: {size_before:.2f} MB")
print(f"Model size after quantization: {size_after:.2f} MB")
print(f"Inference time before quantization: {inference_time_before:.2f} ms/sample")
print(f"Inference time after quantization: {inference_time_after:.2f} ms/sample")
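One step the listing above omits: validation accuracy is measured only before quantization, but dynamic quantization can shift predictions slightly, so the 90% constraint should be re-checked on the quantized model. A minimal sketch of such a re-check, using toy stand-ins (`ToyTextDataset`, `ToyClassifier` are hypothetical placeholders; in the experiment you would pass `quantized_model` and the real `val_loader` to `evaluate`):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# Toy stand-ins so the sketch runs end to end
class ToyTextDataset(Dataset):
    def __init__(self, n=32):
        torch.manual_seed(0)
        self.x = torch.randint(0, 100, (n, 8))       # fake token ids
        self.mask = torch.ones(n, 8, dtype=torch.long)
        self.y = torch.randint(0, 2, (n,))           # binary labels
    def __len__(self):
        return len(self.y)
    def __getitem__(self, i):
        return {'input_ids': self.x[i], 'attention_mask': self.mask[i], 'labels': self.y[i]}

class ToyClassifier(nn.Module):
    """Mimics the HF interface: forward(...) returns an object with .logits."""
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(100, 16)
        self.head = nn.Linear(16, 2)
    def forward(self, input_ids, attention_mask=None):
        h = self.emb(input_ids).mean(dim=1)  # mean-pool token embeddings
        return type('Output', (), {'logits': self.head(h)})()

def evaluate(model, loader):
    """Return accuracy (%) of an HF-style classifier over a loader of dict batches."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in loader:
            logits = model(batch['input_ids'], attention_mask=batch['attention_mask']).logits
            preds = torch.argmax(logits, dim=1)
            correct += (preds == batch['labels']).sum().item()
            total += batch['labels'].size(0)
    return correct / total * 100

acc = evaluate(ToyClassifier(), DataLoader(ToyTextDataset(), batch_size=8))
print(f"accuracy: {acc:.1f}%")
```

In the real experiment, `assert evaluate(quantized_model, val_loader) >= 90.0` would confirm the constraint holds after quantization, not just after distillation.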
Trained a smaller student model using knowledge distillation from the large teacher model to reduce size while keeping accuracy.
Applied dynamic quantization to the student model to further reduce model size and speed up inference.
Measured validation accuracy, model size, and inference time before and after optimization.
Results Interpretation

Before Optimization: Accuracy 93%, Size 500MB, Inference 500ms/sample

After Optimization: Accuracy 91.5%, Size 75MB, Inference 120ms/sample

Model distillation and quantization can greatly reduce model size and speed up inference with only a small drop in accuracy, making models practical for deployment on limited devices.
Bonus Experiment
Try pruning the student model weights after distillation to further reduce size and compare accuracy.
💡 Hint
Use magnitude-based pruning to zero out the smallest weights, then fine-tune the pruned model to recover any lost accuracy.
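A minimal sketch of the hint above using `torch.nn.utils.prune`, shown on a toy two-layer model rather than the actual student (the model and the 40% pruning amount are illustrative assumptions; fine-tuning after pruning is omitted):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))

# Global magnitude pruning: zero the 40% of weights with the smallest
# absolute value, ranked across all listed layers together
parameters_to_prune = [(m, 'weight') for m in model.modules() if isinstance(m, nn.Linear)]
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.4)

# Make the pruning permanent (folds the mask into the weight tensor)
for module, name in parameters_to_prune:
    prune.remove(module, name)

sparsity = sum((m.weight == 0).sum().item() for m, _ in parameters_to_prune) / \
           sum(m.weight.numel() for m, _ in parameters_to_prune)
print(f"Global sparsity: {sparsity:.0%}")
```

Note that unstructured pruning alone does not shrink the saved `state_dict` (zeros still occupy space); the size win comes from sparse storage formats or from combining pruning with quantization, and a short fine-tuning pass afterwards typically recovers most of the accuracy.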