import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.quantization
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# A small MLP classifier for 28x28 MNIST digits (784 -> 300 -> 100 -> 10).
class Net(nn.Module):
    """Three-layer fully connected network for MNIST classification."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 300)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(300, 100)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        """Flatten the input image(s) and return raw class logits."""
        flat = x.view(-1, 28 * 28)
        hidden = self.relu2(self.fc2(self.relu1(self.fc1(flat))))
        return self.fc3(hidden)
# Build the model. A real pipeline would restore pretrained weights with
# load_state_dict; for this demo we train from scratch for a few epochs
# to stand in for a pretrained checkpoint.
model = Net()

# MNIST datasets and loaders (ToTensor only: pixel values in [0, 1]).
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1000)

# Short supervised training run (3 epochs, Adam, cross-entropy).
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
model.train()
for _epoch in range(3):
    for batch_images, batch_labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch_images), batch_labels)
        loss.backward()
        optimizer.step()
def evaluate(model, loader):
    """Return top-1 accuracy (percent) of `model` over all batches in `loader`.

    `loader` yields (images, labels) pairs; the model is put in eval mode
    and no gradients are tracked during the forward passes.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            preds = model(images).argmax(dim=1)
            seen += labels.size(0)
            hits += (preds == labels).sum().item()
    return 100 * hits / seen
# Baseline accuracy before any compression.
val_acc_before = evaluate(model, val_loader)

# Globally prune 20% of the weights with the lowest L1 magnitude across all
# fully connected layers. This installs a weight_orig / weight_mask
# reparameterization on each listed module.
parameters_to_prune = [
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

# BUG FIX: fine-tune WHILE the pruning masks are still attached. The mask
# zeroes the pruned weights on every forward pass, so gradients through the
# masked positions are zero and the pruned weights cannot be revived.
# (The original code called prune.remove() first and fine-tuned afterwards,
# which lets gradient updates regrow the zeroed weights — the 20% sparsity
# was silently lost.)
model.train()
for epoch in range(1):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# Now make the pruning permanent: bake mask * weight_orig back into `weight`
# and drop the reparameterization hooks. Use the recorded parameter name
# instead of hard-coding 'weight'.
for module, param_name in parameters_to_prune:
    prune.remove(module, param_name)

val_acc_after_pruning = evaluate(model, val_loader)
# Dynamic quantization: the weights of every nn.Linear module are stored as
# int8 and dequantized on the fly during the matmul; activations stay float.
model_quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)
val_acc_after_quant = evaluate(model_quantized, val_loader)

# Compare serialized state_dict sizes (KB) via in-memory buffers.
# NOTE(review): size_before is measured on the already-pruned dense model;
# unstructured pruning does not change tensor shapes, so this should match
# the unpruned model's serialized size — confirm if exactness matters.
import io

def _state_dict_kb(m):
    # Serialize the model's state_dict into RAM and report its size in KB.
    buf = io.BytesIO()
    torch.save(m.state_dict(), buf)
    return buf.getbuffer().nbytes / 1024

size_before = _state_dict_kb(model)
size_after = _state_dict_kb(model_quantized)
# Rough CPU inference benchmark: average latency per image over 10 forward
# passes of a fixed 100-image batch. (Removed the unused `import numpy as np`
# — nothing in this script referenced numpy.)
import time

model_quantized.eval()
images, _ = next(iter(val_loader))
images = images[:100]  # val_loader batches are 1000 images; keep the first 100
# Use perf_counter: monotonic, high resolution, intended for benchmarking
# (time.time() can jump with wall-clock adjustments).
start = time.perf_counter()
with torch.no_grad():
    for _ in range(10):
        _ = model_quantized(images)
end = time.perf_counter()
inference_time = (end - start) / 10 / images.size(0) * 1000  # ms per image
# Final report. The f-string below is runtime output (including its leading
# and trailing newlines from the triple-quoted literal) — its text is the
# script's user-facing contract, so it is left exactly as-is.
code_output = f"""
Validation accuracy before optimization: {val_acc_before:.2f}%
Validation accuracy after pruning: {val_acc_after_pruning:.2f}%
Validation accuracy after quantization: {val_acc_after_quant:.2f}%
Model size before optimization: {size_before:.2f} KB
Model size after quantization: {size_after:.2f} KB
Inference time per image after quantization: {inference_time:.2f} ms
"""
print(code_output)