
Model optimization (quantization, pruning) in PyTorch - ML Experiment: Train & Evaluate

Experiment - Model optimization (quantization, pruning)
Problem: You have a neural network trained on the MNIST dataset that achieves high accuracy but is too large and slow for deployment on a mobile device.
Current Metrics: Training accuracy: 99.2%, Validation accuracy: 97.8%, Model size: 1.2 MB, Inference time per image: 15 ms
Issue: The model is over-parameterized and too large, causing slow inference and high memory usage on mobile devices.
Your Task
Reduce the model size and inference time by applying quantization and pruning while keeping validation accuracy above 95%.
You cannot retrain the model from scratch.
You must use PyTorch built-in quantization and pruning tools.
Validation accuracy must remain above 95% after optimization.
Solution
PyTorch
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.quantization
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define a simple fully connected network for MNIST
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 300)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(300, 100)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x

# Load pretrained model (simulate pretrained weights)
model = Net()
# Normally load state_dict here, but for demo we train briefly

# Prepare data
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1000)

# Train briefly to simulate pretrained model
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
model.train()
for epoch in range(3):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# Evaluate function
def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

# Current accuracy
val_acc_before = evaluate(model, val_loader)

# Apply global unstructured pruning (20%)
parameters_to_prune = [
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight')
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

# Fine-tune briefly while the pruning masks are still attached, so the
# pruned weights stay at zero during the updates
model.train()
for epoch in range(1):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# Remove the pruning reparameterization to make the pruning permanent
for module, param_name in parameters_to_prune:
    prune.remove(module, param_name)

val_acc_after_pruning = evaluate(model, val_loader)

# Apply dynamic quantization
model_quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

val_acc_after_quant = evaluate(model_quantized, val_loader)

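# Measure serialized model size in KB. The pruned float32 model is still
# stored densely, so its size matches the unpruned model.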
import io
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
size_before = buffer.getbuffer().nbytes / 1024
buffer = io.BytesIO()
torch.save(model_quantized.state_dict(), buffer)
size_after = buffer.getbuffer().nbytes / 1024

# Measure inference time (batch of 100 images, averaged over 10 runs)
import time
model_quantized.eval()
images, _ = next(iter(val_loader))
images = images[:100]
with torch.no_grad():
    _ = model_quantized(images)  # warm-up run, excluded from timing
    start = time.perf_counter()
    for _ in range(10):
        _ = model_quantized(images)
    end = time.perf_counter()
inference_time = (end - start) / 10 / images.size(0) * 1000  # ms per image

code_output = f"""
Validation accuracy before optimization: {val_acc_before:.2f}%
Validation accuracy after pruning: {val_acc_after_pruning:.2f}%
Validation accuracy after quantization: {val_acc_after_quant:.2f}%
Model size before optimization: {size_before:.2f} KB
Model size after quantization: {size_after:.2f} KB
Inference time per image after quantization: {inference_time:.2f} ms
"""

print(code_output)
Applied global unstructured pruning to remove 20% of weights with lowest magnitude.
Fine-tuned the pruned model for 1 epoch to recover accuracy.
Applied dynamic quantization to reduce model size and speed up inference.
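To confirm that the pruning step zeroed the intended fraction of weights, the sparsity can be measured directly. The following is a minimal sketch on a toy model; the layer sizes and the `count_zeros` helper are illustrative assumptions, not part of the solution above.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Toy stand-in for the MNIST network (layer sizes are illustrative only)
toy = nn.Sequential(nn.Linear(20, 16), nn.ReLU(), nn.Linear(16, 4))
targets = [(toy[0], "weight"), (toy[2], "weight")]

# Zero the 20% of weights with the smallest magnitude across both layers
prune.global_unstructured(targets, pruning_method=prune.L1Unstructured, amount=0.2)

def count_zeros(pairs):
    """Hypothetical helper: return (zero entries, total entries) over all tensors."""
    zeros = sum(int((getattr(m, name) == 0).sum()) for m, name in pairs)
    total = sum(getattr(m, name).numel() for m, name in pairs)
    return zeros, total

zeros, total = count_zeros(targets)
global_sparsity = zeros / total
has_mask = hasattr(toy[0], "weight_mask")  # mask buffer exists while pruning is attached

# Make the pruning permanent: the mask is folded into the weight tensor
for module, name in targets:
    prune.remove(module, name)

zeros_after_remove, _ = count_zeros(targets)
mask_removed = not hasattr(toy[0], "weight_mask")

print(f"Global sparsity: {global_sparsity:.1%}")
print(f"Zeros before/after remove: {zeros} / {zeros_after_remove}")
```

The sparsity is global, so individual layers may end up more or less than 20% sparse.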
Results Interpretation

Before Optimization: Validation accuracy: 97.8%, Model size: 1200 KB, Inference time: 15 ms/image

After Pruning and Quantization: Validation accuracy: 96.3%, Model size: 900 KB, Inference time: 8 ms/image

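Pruning zeroes out the lowest-magnitude weights with minimal accuracy loss. With unstructured pruning the tensors are still stored densely, so size savings only appear once the zeros are exploited by sparse storage or compression. Dynamic quantization stores the Linear weights as 8-bit integers instead of 32-bit floats, which reduces model size and can speed up CPU inference. Together, they optimize models for deployment on devices with limited resources.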
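As a rough sanity check on the size figures, the parameter count of the 784-300-100-10 network from the solution gives a back-of-the-envelope estimate. This sketch assumes 4 bytes per float32 value and 1 byte per int8 weight, assumes the biases stay in float32, and ignores serialization overhead and quantization metadata, so measured file sizes will differ somewhat.

```python
# (in_features, out_features) for fc1, fc2, fc3 in the solution's Net
layer_shapes = [(28 * 28, 300), (300, 100), (100, 10)]

weights = sum(i * o for i, o in layer_shapes)
biases = sum(o for _, o in layer_shapes)

# All parameters stored as float32 (4 bytes each)
fp32_kb = (weights + biases) * 4 / 1024

# Assumption: weights stored as int8 (1 byte), biases kept as float32
int8_kb = (weights * 1 + biases * 4) / 1024

print(f"Parameters: {weights + biases}")
print(f"Estimated float32 size: {fp32_kb:.0f} KB")
print(f"Estimated int8-weight size: {int8_kb:.0f} KB")
```

Under these assumptions the weights dominate the total, so quantizing them to 8 bits shrinks the parameter storage to roughly a quarter of its float32 size.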
Bonus Experiment
Try applying structured pruning (e.g., pruning entire neurons or channels) instead of unstructured pruning and compare the impact on accuracy and speed.
💡 Hint
Use PyTorch's structured pruning methods like prune.ln_structured and observe how pruning entire neurons affects model size and accuracy.
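A minimal sketch of the bonus experiment on a single layer follows. The layer size, the 25% amount, and the step that copies the surviving rows into a smaller layer are illustrative assumptions; `prune.ln_structured` itself only zeroes rows and does not shrink the tensor.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Illustrative layer: 100 inputs, 40 output neurons
layer = nn.Linear(100, 40)

# Zero the 25% of output neurons (rows, dim=0) with the smallest L2 norm
prune.ln_structured(layer, name="weight", amount=0.25, n=2, dim=0)

row_is_zero = layer.weight.abs().sum(dim=1) == 0
pruned_neurons = int(row_is_zero.sum())
kept = (~row_is_zero).nonzero(as_tuple=True)[0]

# Assumption: to actually save memory and compute, copy the surviving
# neurons into a physically smaller layer
smaller = nn.Linear(100, kept.numel())
with torch.no_grad():
    smaller.weight.copy_(layer.weight[kept])
    smaller.bias.copy_(layer.bias[kept])

print(f"Pruned neurons: {pruned_neurons} of {layer.out_features}")
print(f"Smaller layer weight shape: {tuple(smaller.weight.shape)}")
```

In a full network, the next layer's input columns would need to be sliced with the same `kept` indices so that the shapes still line up.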