Model optimization (distillation, quantization) in NLP - ML Experiment: Train & Evaluate

Experiment - Model optimization (distillation, quantization)
Problem: You have a large NLP model that performs well on text classification but is too slow and too large to deploy on mobile devices.
Current Metrics: Training accuracy: 95%, Validation accuracy: 93%, Model size: 500MB, Inference time per sample: 500ms
Issue: Despite good accuracy, the model's size and latency make it unsuitable for mobile deployment.
Your Task
Reduce the model size and inference time by applying model distillation and quantization while keeping validation accuracy above 90%.
You must keep the original training data and model architecture for the teacher model unchanged.
You can only modify the student model size and apply quantization after training.
Validation accuracy must not drop below 90%.
Solution
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, DistilBertForSequenceClassification

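# Assumptions: train_loader and val_loader are DataLoaders that yield dicts with
# 'input_ids', 'attention_mask', and 'labels' tensors, tokenized with the uncased
# BERT vocabulary (distilbert-base-uncased shares it, so both models can consume
# the same input_ids).
# teacher_model: large BERT classifier, already fine-tuned on the task
# student_model: smaller DistilBERT classifier trained to match the teacher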

class DistillationLoss(nn.Module):
    def __init__(self, temperature=2.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
        self.ce = nn.CrossEntropyLoss()

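    def forward(self, student_logits, teacher_logits, labels):
        # KLDivLoss expects log-probabilities as input and probabilities as target
        student_log_probs = nn.functional.log_softmax(student_logits / self.temperature, dim=1)
        teacher_probs = nn.functional.softmax(teacher_logits / self.temperature, dim=1)
        # Scale by T^2 so the soft-target gradients stay comparable across temperatures
        distill_loss = self.kl_div(student_log_probs, teacher_probs) * (self.temperature ** 2)
        # Standard cross-entropy against the ground-truth labels
        student_loss = self.ce(student_logits, labels)
        # alpha weights the hard-label loss; (1 - alpha) weights the distillation loss
        return self.alpha * student_loss + (1 - self.alpha) * distill_loss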

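# Load teacher model (large).
# NOTE: 'bert-base-uncased' is a placeholder. Loaded this way, the classification
# head is randomly initialized; point from_pretrained at your fine-tuned teacher
# checkpoint instead.
teacher_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
teacher_model.eval()  # disable dropout; the teacher is never updated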

# Load student model (smaller)
student_model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)

optimizer = optim.Adam(student_model.parameters(), lr=5e-5)
criterion = DistillationLoss(temperature=2.0, alpha=0.5)

# Training loop for distillation
for epoch in range(3):
    student_model.train()
    for batch in train_loader:
        inputs = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        with torch.no_grad():
            teacher_logits = teacher_model(inputs, attention_mask=attention_mask).logits

        student_logits = student_model(inputs, attention_mask=attention_mask).logits

        loss = criterion(student_logits, teacher_logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate student model accuracy on validation set
student_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in val_loader:
        inputs = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        outputs = student_model(inputs, attention_mask=attention_mask).logits
        preds = torch.argmax(outputs, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
val_accuracy = correct / total * 100

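# Apply post-training dynamic quantization (runs on CPU; Linear weights are stored as int8)
quantized_model = torch.quantization.quantize_dynamic(
    student_model, {nn.Linear}, dtype=torch.qint8
)

# Re-check validation accuracy: the 90% target applies to the quantized model as well
q_correct = 0
q_total = 0
with torch.no_grad():
    for batch in val_loader:
        q_logits = quantized_model(batch['input_ids'], attention_mask=batch['attention_mask']).logits
        q_correct += (torch.argmax(q_logits, dim=1) == batch['labels']).sum().item()
        q_total += batch['labels'].size(0)
quantized_val_accuracy = q_correct / q_total * 100
print(f"Validation accuracy after quantization: {quantized_val_accuracy:.2f}%")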

# Measure model size and inference time
import time
import os

def model_size(model):
    torch.save(model.state_dict(), 'temp.pth')
    size = os.path.getsize('temp.pth') / 1e6  # MB
    os.remove('temp.pth')
    return size

size_before = model_size(student_model)
size_after = model_size(quantized_model)

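# Inference time measurement on one validation batch (CPU).
# A single timed run is noisy; average several runs for stable numbers.
sample_batch = next(iter(val_loader))
input_sample = sample_batch['input_ids']
attention_sample = sample_batch['attention_mask']

with torch.no_grad():  # no autograd bookkeeping while timing
    # Warm-up pass so one-time setup cost is not counted
    _ = student_model(input_sample, attention_mask=attention_sample)
    start = time.perf_counter()
    _ = student_model(input_sample, attention_mask=attention_sample)
    end = time.perf_counter()
inference_time_before = (end - start) / input_sample.size(0) * 1000  # ms per sample

with torch.no_grad():
    _ = quantized_model(input_sample, attention_mask=attention_sample)  # warm-up
    start = time.perf_counter()
    _ = quantized_model(input_sample, attention_mask=attention_sample)
    end = time.perf_counter()
inference_time_after = (end - start) / input_sample.size(0) * 1000  # ms per sample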

print(f"Validation accuracy after distillation: {val_accuracy:.2f}%")
print(f"Model size before quantization: {size_before:.2f} MB")
print(f"Model size after quantization: {size_after:.2f} MB")
print(f"Inference time before quantization: {inference_time_before:.2f} ms/sample")
print(f"Inference time after quantization: {inference_time_after:.2f} ms/sample")
Trained a smaller student model using knowledge distillation from the large teacher model to reduce size while keeping accuracy.
Applied dynamic quantization to the student model to further reduce model size and speed up inference.
Measured validation accuracy, model size, and inference time before and after optimization.
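To make the loss used in the solution concrete, here is a minimal plain-Python sketch (no PyTorch) of the same arithmetic for a single example: a temperature-softened softmax for teacher and student, a KL divergence term scaled by T squared, and a cross-entropy term mixed in with alpha. The logits and helper names are made up for illustration and are not part of the original solution.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over a list of logits after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, label, temperature=2.0, alpha=0.5):
    """Single-example alpha * CE + (1 - alpha) * T^2 * KL(teacher || student)."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # KL divergence between the softened teacher and student distributions
    kl = sum(t * math.log(t / s) for t, s in zip(p_teacher, p_student) if t > 0)
    # Cross-entropy of the unsoftened student prediction against the true label
    hard_ce = -math.log(softmax(student_logits)[label])
    return alpha * hard_ce + (1 - alpha) * kl * temperature ** 2

# Hypothetical logits for a two-class example whose true label is class 1
teacher_logits = [-1.0, 3.0]
student_logits = [0.0, 1.0]

print(softmax(teacher_logits, 1.0))  # sharp teacher distribution at T=1
print(softmax(teacher_logits, 2.0))  # softer distribution at T=2
print(distillation_loss(student_logits, teacher_logits, label=1))
```

A higher temperature spreads probability mass onto the non-target class, which is the extra signal the student learns from. When the student's logits equal the teacher's, the KL term is zero and only the weighted cross-entropy remains.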
Results Interpretation

Before Optimization: Accuracy 93%, Size 500MB, Inference 500ms/sample

After Optimization: Accuracy 91.5%, Size 75MB, Inference 120ms/sample

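In this example, distillation followed by quantization makes the model about 6.7x smaller (500MB to 75MB) and about 4.2x faster (500ms to 120ms per sample) at a cost of 1.5 points of validation accuracy. That keeps accuracy above the 90% target and makes the model practical for deployment on resource-limited devices.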
Bonus Experiment
Try pruning the student model weights after distillation to further reduce size and compare accuracy.
💡 Hint
Use magnitude-based pruning to zero out small weights and fine-tune the model to recover accuracy.
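As a rough illustration of the hint, the sketch below applies magnitude-based pruning to a small weight matrix stored as nested Python lists. The function name, the example weights, and the 50% sparsity level are assumptions chosen for the demo. In PyTorch, `torch.nn.utils.prune.l1_unstructured` offers the same idea for module parameters.

```python
def magnitude_prune(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest absolute value.

    Ties at the threshold are all pruned, so slightly more than the requested
    fraction can be zeroed when several weights share the same magnitude.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be between 0 and 1")
    magnitudes = sorted(abs(w) for row in weights for w in row)
    n_prune = int(len(magnitudes) * sparsity)
    if n_prune == 0:
        return [row[:] for row in weights]
    threshold = magnitudes[n_prune - 1]  # largest magnitude that still gets pruned
    return [[0.0 if abs(w) <= threshold else w for w in row] for row in weights]

# Hypothetical 2x3 weight matrix
weights = [[0.8, -0.05, 0.3],
           [-0.02, 0.6, -0.1]]

pruned = magnitude_prune(weights, sparsity=0.5)
print(pruned)

zeros = sum(w == 0.0 for row in pruned for w in row)
print(f"Sparsity achieved: {zeros / 6:.0%}")
```

Zeroed weights only shrink the saved model if they are stored in a sparse or compressed format, so measure file size as well as accuracy when comparing against the unpruned student.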

Practice

1. What is the main goal of model distillation in NLP?
easy
A. To increase the number of layers in a neural network
B. To add more training data for better accuracy
C. To convert text data into numerical vectors
D. To train a smaller model to mimic a larger model's behavior

Solution

  1. Step 1: Understand model distillation concept

    Model distillation is about making a smaller model learn from a bigger, well-trained model.
  2. Step 2: Identify the goal of distillation

    The goal is to keep performance while reducing model size and complexity.
  3. Final Answer:

    To train a smaller model to mimic a larger model's behavior -> Option D
  4. Quick Check:

    Distillation = smaller model copies bigger model [OK]
Hint: Distillation means small model learns from big model [OK]
Common Mistakes:
  • Confusing distillation with adding layers
  • Thinking distillation increases data size
  • Mixing distillation with data preprocessing
2. Which of the following is the correct way to apply quantization to a model's weights in Python using PyTorch?
easy
A. model.quantize(weights=True)
B. torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
C. torch.quantize(model, dtype=torch.float32)
D. torch.quantization(model, dtype=torch.int32)

Solution

  1. Step 1: Recall PyTorch quantization syntax

    PyTorch uses torch.quantization.quantize_dynamic for dynamic quantization on layers like Linear.
  2. Step 2: Check correct function and parameters

    torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) correctly calls quantize_dynamic with model, target layers, and dtype torch.qint8.
  3. Final Answer:

    torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) -> Option B
  4. Quick Check:

    Dynamic quantization of Linear layers to qint8 = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) [OK]
Hint: Use torch.quantization.quantize_dynamic for quantization [OK]
Common Mistakes:
  • Using non-existent torch.quantize function
  • Passing wrong dtype like float32 instead of qint8
  • Calling quantization as a model method
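To show what reducing weights to 8-bit integers means numerically, here is a small plain-Python sketch of symmetric per-tensor quantization. It is a simplified illustration, not PyTorch's exact scheme; the function names and sample weights are assumptions.

```python
def quantize_int8(values):
    """Symmetric per-tensor quantization: map floats to integers in [-127, 127]."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale

def dequantize(q, scale):
    """Recover approximate float values from the stored integers."""
    return [x * scale for x in q]

# Hypothetical float32 weights from one layer
weights = [0.52, -1.27, 0.003, 0.9]

q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))

print(q)          # integers that fit in one byte each
print(scale)      # single float stored alongside the integers
print(max_error)  # rounding error, at most half of one quantization step
```

Each weight drops from 4 bytes to 1 byte plus a shared scale, and the price is a rounding error bounded by half the scale. Very small weights such as 0.003 round to zero.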
3. Given the following code snippet for distillation, what will be the output loss value if the student model perfectly mimics the teacher model's outputs?
teacher_outputs = torch.tensor([0.1, 0.9])
student_outputs = torch.tensor([0.1, 0.9])
loss_fn = torch.nn.MSELoss()
loss = loss_fn(student_outputs, teacher_outputs)
print(loss.item())
medium
A. 0.0
B. 0.5
C. 1.0
D. Cannot compute due to shape mismatch

Solution

  1. Step 1: Understand MSELoss calculation

    MSELoss calculates mean squared error between student and teacher outputs.
  2. Step 2: Calculate loss for identical outputs

    Since student_outputs equals teacher_outputs, difference is zero, so loss is 0.0.
  3. Final Answer:

    0.0 -> Option A
  4. Quick Check:

    Identical outputs give zero MSE loss [OK]
Hint: Same outputs mean zero loss in MSE [OK]
Common Mistakes:
  • Assuming loss is 1.0 by default
  • Confusing loss with accuracy
  • Thinking shape mismatch error occurs
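The calculation behind this answer can be checked by hand. The sketch below reimplements mean squared error in plain Python, matching the default mean reduction of `torch.nn.MSELoss`; the second set of student outputs is a made-up contrast case.

```python
def mse(predictions, targets):
    """Mean of the squared element-wise differences."""
    if len(predictions) != len(targets):
        raise ValueError("inputs must have the same length")
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(predictions)

teacher_outputs = [0.1, 0.9]

perfect_student = [0.1, 0.9]    # identical to the teacher
imperfect_student = [0.3, 0.7]  # hypothetical student, off by 0.2 on each entry

print(mse(perfect_student, teacher_outputs))
print(mse(imperfect_student, teacher_outputs))
```

The first call gives exactly zero because every difference is zero. The second gives about 0.04, the average of two squared errors of 0.2.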
4. You tried to quantize a model but got an error: AttributeError: 'MyModel' object has no attribute 'quantize'. What is the likely cause?
medium
A. The model class does not have a built-in quantize method
B. You forgot to import torch
C. Quantization only works on CPU, not GPU
D. The model is already quantized

Solution

  1. Step 1: Analyze the error message

    The error says the model object lacks a 'quantize' method, meaning it is not defined.
  2. Step 2: Understand quantization usage

    Quantization is applied through PyTorch functions such as torch.quantization.quantize_dynamic, not through a method on the model, so calling model.quantize() raises this error.
  3. Final Answer:

    The model class does not have a built-in quantize method -> Option A
  4. Quick Check:

    Quantize is a function, not a model method [OK]
Hint: Quantize via torch functions, not model methods [OK]
Common Mistakes:
  • Trying to call quantize as model.quantize()
  • Blaming a missing import, which would raise a NameError rather than an AttributeError
  • Blaming the device (CPU vs. GPU) for what is an attribute lookup error
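The error can be reproduced without PyTorch. In this sketch, `MyModel` is a stand-in class and `quantize_model` is a hypothetical free function that simply rounds weights; neither is a real library API. It shows the failing method call next to the function-style call that works.

```python
class MyModel:
    """Stand-in for a model class; it defines no quantize() method."""

    def __init__(self, weights):
        self.weights = weights

def quantize_model(model):
    """Hypothetical free function: returns a new model with rounded weights."""
    return MyModel([round(w) for w in model.weights])

model = MyModel([0.4, 1.6, -2.2])

# Calling a method the class never defined fails at attribute lookup
try:
    model.quantize()
except AttributeError as err:
    message = str(err)
    print(message)

# Passing the model to a function works, mirroring how PyTorch exposes quantization
quantized = quantize_model(model)
print(quantized.weights)
```

The original model is left untouched and a converted copy is returned, which is the same calling pattern as `torch.quantization.quantize_dynamic(model, ...)`.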
5. You want to deploy a chatbot on a mobile device with limited memory and CPU. Which combination of model optimization techniques is best to reduce size and speed up inference without losing much accuracy?
hard
A. Use quantization first, then retrain the large model from scratch
B. Only increase the training data size to improve accuracy
C. Use distillation to train a smaller model, then apply quantization to reduce precision
D. Add more layers to the model and use float64 precision

Solution

  1. Step 1: Identify constraints and goals

    Mobile devices need small, fast models with good accuracy.
  2. Step 2: Choose suitable optimization techniques

    Distillation produces a smaller model with fewer parameters; quantization lowers the numeric precision of the weights (for example, float32 to int8) to save space and speed up inference.
  3. Step 3: Combine techniques for best effect

    Using distillation first then quantization is a common, effective approach.
  4. Final Answer:

    Use distillation to train a smaller model, then apply quantization to reduce precision -> Option C
  5. Quick Check:

    Distillation + quantization = small, fast, accurate model [OK]
Hint: Distill first, then quantize for mobile deployment [OK]
Common Mistakes:
  • Skipping quantization when targeting mobile devices
  • Adding layers, which increases model size and slows inference
  • Retraining the large model from scratch after quantization, which wastes effort
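The two techniques stack because they act on different factors: distillation cuts the number of parameters, and quantization cuts the bytes stored per parameter. The back-of-the-envelope sketch below uses approximate published parameter counts for `bert-base-uncased` and `distilbert-base-uncased`; the helper name is an assumption, and real file sizes will differ.

```python
BYTES_PER_PARAM = {"float32": 4, "int8": 1}

def estimated_size_mb(num_params, dtype):
    """Rough model size: parameter count times bytes per parameter."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e6

teacher_params = 110_000_000  # approximate count for bert-base-uncased
student_params = 66_000_000   # approximate count for distilbert-base-uncased

teacher_fp32 = estimated_size_mb(teacher_params, "float32")
student_fp32 = estimated_size_mb(student_params, "float32")
# Lower bound only: dynamic quantization converts the Linear layers,
# while embeddings and other layers stay in float32
student_int8 = estimated_size_mb(student_params, "int8")

print(f"Teacher, float32: {teacher_fp32:.0f} MB")
print(f"Student, float32: {student_fp32:.0f} MB")
print(f"Student, int8 (lower bound): {student_int8:.0f} MB")
```

Under these assumptions the estimate falls from 440 MB to 264 MB through distillation alone, and toward 66 MB once the weights are stored as int8.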