Model optimization makes machine learning models smaller and faster, which is useful for running them on devices with limited power or memory.
Model optimization (quantization, pruning) in PyTorch
Introduction
When you want to run a model on a smartphone or small device.
When you need faster predictions without changing the model's accuracy much.
When you want to save storage space for your model files.
When deploying models in places with limited internet or hardware.
When reducing energy use is important, like in battery-powered devices.
Syntax
PyTorch
import torch
import torch.nn.utils.prune as prune

# Pruning example
prune.l1_unstructured(module, name='weight', amount=0.3)

# Quantization example
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
torch.quantization.convert(model, inplace=True)
Pruning removes less important weights to make the model smaller.
Quantization reduces numeric precision (for example, from 32-bit floats to 8-bit integers) to speed up the model and shrink its memory footprint.
Examples
This prunes the 20% of weights with the smallest magnitude (L1 norm) in layer1.
PyTorch
prune.l1_unstructured(model.layer1, name='weight', amount=0.2)
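To see the effect, you can check what fraction of the weights are zero after pruning. This is a minimal sketch using a standalone nn.Linear layer as a stand-in for model.layer1:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Hypothetical layer standing in for model.layer1
layer = nn.Linear(10, 10)

# Prune the 20% of weights with the smallest L1 magnitude
prune.l1_unstructured(layer, name='weight', amount=0.2)

# The pruned weights are zeroed out by the mask
total = layer.weight.numel()
zeros = int((layer.weight == 0).sum())
sparsity = zeros / total
print(f"Sparsity: {sparsity:.0%}")  # ~20% of weights are now zero
```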
This removes the pruning reparameterization and makes pruning permanent.
PyTorch
prune.remove(model.layer1, 'weight')
This prepares and converts the model for 8-bit quantization using the 'fbgemm' backend.
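Under the hood, pruning reparameterizes the layer: the effective weight is computed from the stored original values (weight_orig) and a binary mask (weight_mask), and prune.remove folds the mask back in. A small sketch with a standalone layer:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 4)
prune.l1_unstructured(layer, name='weight', amount=0.5)

# While pruning is active, 'weight' is recomputed from
# 'weight_orig' (a parameter) and 'weight_mask' (a buffer)
params_before = {name for name, _ in layer.named_parameters()}
buffers_before = {name for name, _ in layer.named_buffers()}
print(params_before)   # {'weight_orig', 'bias'}
print(buffers_before)  # {'weight_mask'}

# remove() applies the mask permanently and restores 'weight'
# as an ordinary parameter
prune.remove(layer, 'weight')
params_after = {name for name, _ in layer.named_parameters()}
print(params_after)    # {'weight', 'bias'}
```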
PyTorch
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
torch.quantization.convert(model, inplace=True)
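PyTorch also offers dynamic quantization, a simpler alternative to the static workflow above: weights are converted to int8 up front and activations are quantized on the fly, so no calibration pass is needed. A minimal sketch:

```python
import torch
import torch.nn as nn

# Quantize all Linear layers dynamically to 8-bit integers
model = nn.Sequential(nn.Linear(10, 5))
quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

# The quantized model still accepts ordinary float inputs
out = quantized(torch.randn(1, 10))
print(quantized)  # the Linear layer is now DynamicQuantizedLinear
print(out.shape)  # torch.Size([1, 5])
```

Dynamic quantization is a good first choice for models dominated by Linear or LSTM layers.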
Sample Model
This code shows how to prune 40% of the weights in a simple linear layer and then statically quantize the model to 8-bit. It prints the number of nonzero weights before and after pruning and shows the model output after quantization.
PyTorch
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.quantization

# Define a simple model. QuantStub/DeQuantStub mark where tensors are
# converted to and from int8, which static quantization requires.
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(10, 5)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        return self.dequant(x)

model = SimpleModel()
model.eval()

# Before pruning: count nonzero weights
initial_nonzero = torch.count_nonzero(model.fc.weight).item()

# Prune 40% of the weights in the fc layer
prune.l1_unstructured(model.fc, name='weight', amount=0.4)

# Remove the pruning reparameterization to make pruning permanent
prune.remove(model.fc, 'weight')

# Count nonzero weights after pruning
pruned_nonzero = torch.count_nonzero(model.fc.weight).item()

# Prepare the model for static 8-bit quantization
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)

# Calibrate with sample data (real data gives better activation ranges)
model(torch.randn(1, 10))

# Convert to a quantized model
torch.quantization.convert(model, inplace=True)

# Check the model output
input_data = torch.randn(1, 10)
output = model(input_data)
print(f"Initial nonzero weights: {initial_nonzero}")
print(f"Nonzero weights after pruning: {pruned_nonzero}")
print(f"Model output after quantization: {output}")
Important Notes
Pruning can reduce model size but might slightly reduce accuracy.
Quantization works best after some calibration with real data.
Always test your model after optimization to check performance.
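One practical way to check the storage savings is to save the state_dict before and after quantization and compare file sizes. This is a sketch using dynamic quantization; the helper name file_size_kb is made up for this example:

```python
import os
import tempfile
import torch
import torch.nn as nn

def file_size_kb(model):
    # Save the state_dict to a temporary file and report its size in KB
    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
        path = f.name
    torch.save(model.state_dict(), path)
    size = os.path.getsize(path) / 1024
    os.remove(path)
    return size

float_model = nn.Sequential(nn.Linear(256, 256))
quant_model = torch.quantization.quantize_dynamic(
    float_model, {nn.Linear}, dtype=torch.qint8
)

fp32_kb = file_size_kb(float_model)
int8_kb = file_size_kb(quant_model)
print(f"float32: {fp32_kb:.1f} KB")
print(f"int8:    {int8_kb:.1f} KB")
```

Going from 32-bit floats to 8-bit integers shrinks the weight storage to roughly a quarter of its original size, plus a small overhead for scales and zero points.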
Summary
Model optimization makes models smaller and faster.
Pruning removes less important weights.
Quantization reduces number precision to speed up models.