
Model optimization (quantization, pruning) in PyTorch

Introduction

Model optimization makes machine learning models smaller and faster, so they can run on devices with limited compute or memory.

Optimization is worth considering:
When you want to run a model on a smartphone or other small device.
When you need faster predictions without losing much accuracy.
When you want to reduce the storage needed for model files.
When deploying models where connectivity or hardware is limited.
When reducing energy use matters, as in battery-powered devices.
Syntax
PyTorch
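import torch
import torch.nn.utils.prune as prune

# Pruning: module is any layer with a 'weight' parameter, e.g. an nn.Linear
prune.l1_unstructured(module, name='weight', amount=0.3)

# Post-training static quantization (eager mode): model must wrap the
# region to quantize in QuantStub / DeQuantStub
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
# ... run representative input batches through model here (calibration) ...
torch.quantization.convert(model, inplace=True)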

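Pruning sets less important weights, for example those with the smallest magnitude, to zero. The tensor keeps its shape, so the model file only shrinks if the zeros are stored in a sparse or compressed format.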
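Unstructured pruning zeroes entries rather than deleting them. The sketch below prunes 90% of the weights of a hypothetical 100×100 layer (size and amount chosen only for illustration) and counts what remains. The dense tensor keeps its full shape, and only a sparse COO copy made with Tensor.to_sparse actually stores fewer values.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(100, 100)  # hypothetical layer with 10,000 weights

# Zero the 90% of weights with the smallest absolute value, then make it permanent
prune.l1_unstructured(layer, name="weight", amount=0.9)
prune.remove(layer, "weight")

dense = layer.weight.detach()
num_zeros = int((dense == 0).sum())

# COO sparse format stores only the nonzero entries (plus their indices)
sparse = dense.to_sparse().coalesce()
stored_values = sparse.values().numel()

print(f"Dense shape: {tuple(dense.shape)}, zeros: {num_zeros}")
print(f"Values stored in sparse form: {stored_values}")
```

Keep in mind that sparse storage also saves an index per value, so it only pays off at high sparsity.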

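Quantization stores numbers with fewer bits, for example 8-bit integers instead of 32-bit floats. That cuts weight storage roughly 4× and can speed up inference on hardware with fast integer kernels.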
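To see what 8-bit storage means for individual numbers, the sketch below quantizes a few float32 values with torch.quantize_per_tensor. The scale of 0.1 and zero point of 0 are arbitrary example choices. Each value becomes an int8 code, round(x / scale) + zero_point, and dequantizing recovers it only to within half a scale step.

```python
import torch

x = torch.tensor([0.0, 0.123, -1.27, 2.5, 7.04])  # float32: 4 bytes per value

scale, zero_point = 0.1, 0  # arbitrary example quantization parameters
q = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)

codes = q.int_repr()       # the stored int8 integers
restored = q.dequantize()  # back to float32: (code - zero_point) * scale
max_error = (x - restored).abs().max().item()

print("int8 codes:", codes.tolist())
print("restored:", restored.tolist())
print(f"bytes per value: {x.element_size()} -> {codes.element_size()}")
print(f"max rounding error: {max_error:.4f}")
```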

Examples
This zeroes the 20% of weights in model.layer1 with the smallest absolute values (the L1 criterion).
PyTorch
prune.l1_unstructured(model.layer1, name='weight', amount=0.2)
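The selection rule is easy to check on a layer with known weights. The sketch below uses a hypothetical 5-weight layer as a stand-in for model.layer1 and prunes 40% of it. That removes round(0.4 × 5) = 2 weights, namely the two with the smallest absolute values (0.05 and -0.1).

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(5, 1, bias=False)  # hypothetical stand-in for model.layer1
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[0.5, -0.1, 0.05, -2.0, 0.3]]))

prune.l1_unstructured(layer, name="weight", amount=0.4)

# weight_mask holds 0 where a weight was pruned and 1 where it was kept
print("mask:   ", layer.weight_mask.tolist())
print("weights:", layer.weight.tolist())
```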
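Pruning is first applied as a reparameterization. PyTorch keeps the original tensor as weight_orig, stores a weight_mask buffer, and recomputes weight = weight_orig * weight_mask before each forward pass. prune.remove folds the mask into weight and deletes the extra tensors, which makes the pruning permanent.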
PyTorch
prune.remove(model.layer1, 'weight')
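The reparameterization is visible in a module's parameters and buffers. This sketch prunes a hypothetical small layer, lists what it holds while pruned and again after prune.remove, and uses prune.is_pruned to confirm the pruning hook is gone. The describe helper is made up for illustration.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 3)  # hypothetical layer for illustration

def describe(module):
    """Return sorted parameter names and buffer names of a module."""
    params = sorted(name for name, _ in module.named_parameters())
    buffers = sorted(name for name, _ in module.named_buffers())
    return params, buffers

prune.l1_unstructured(layer, name="weight", amount=0.5)
before, pruned_before = describe(layer), prune.is_pruned(layer)
print("while pruned:", before, pruned_before)

prune.remove(layer, "weight")
after, pruned_after = describe(layer), prune.is_pruned(layer)
print("after remove:", after, pruned_after)
```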
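This prepares and converts the model for 8-bit quantization using the 'fbgemm' backend, which targets x86 CPUs. The model should be in eval mode, and representative inputs should be run through it between prepare and convert so the inserted observers can record activation ranges.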
PyTorch
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
# calibration: pass representative input batches through model here
torch.quantization.convert(model, inplace=True)
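'fbgemm' is the backend for x86 CPUs, and on ARM devices the usual choice is 'qnnpack'. The qconfig's backend should match the engine that runs the quantized kernels. The sketch below picks a backend from torch.backends.quantized.supported_engines. It assumes a PyTorch build that includes at least one of the two backends.

```python
import torch

engines = torch.backends.quantized.supported_engines
# Prefer fbgemm (x86) and fall back to qnnpack (ARM); assumes one of them is built in
backend = "fbgemm" if "fbgemm" in engines else "qnnpack"

torch.backends.quantized.engine = backend  # engine used by quantized kernels
qconfig = torch.quantization.get_default_qconfig(backend)

print("available engines:", engines)
print("using:", torch.backends.quantized.engine)
print(qconfig)
```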
Sample Model

This code prunes 40% of the weights in a simple linear layer and then applies 8-bit static quantization to the model. It prints the number of nonzero weights before and after pruning, and the model output after quantization.

PyTorch
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.quantization

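# Define a simple model. QuantStub and DeQuantStub mark where tensors are
# converted to and from int8. Eager-mode static quantization needs them;
# without them the quantized layer would receive a float tensor and fail.
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(10, 5)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        return self.dequant(x)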

model = SimpleModel()

# Before pruning: count nonzero weights
initial_nonzero = torch.count_nonzero(model.fc.weight).item()

# Prune 40% of weights in fc layer
prune.l1_unstructured(model.fc, name='weight', amount=0.4)

# Remove pruning reparameterization to make pruning permanent
prune.remove(model.fc, 'weight')

# Count nonzero weights after pruning
pruned_nonzero = torch.count_nonzero(model.fc.weight).item()

# Prepare model for post-training quantization (eval mode)
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)

# Calibration: random data stands in for representative real inputs here
for _ in range(10):
    model(torch.randn(8, 10))

# Convert to quantized model
torch.quantization.convert(model, inplace=True)

# Check model output
input_data = torch.randn(1, 10)
output = model(input_data)

print(f"Initial nonzero weights: {initial_nonzero}")
print(f"Nonzero weights after pruning: {pruned_nonzero}")
print(f"Model output after quantization: {output}")
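Unstructured pruning leaves dense weight tensors the same size, so in a pipeline like this one most of the storage saving comes from quantization. The sketch below measures it by serializing the state_dict of a float model and of a quantized copy into memory. Some details are assumptions made for illustration: a hypothetical 512×512 layer (large enough that weights dominate file overhead), a serialized_size helper, and random calibration data used purely as a placeholder.

```python
import copy
import io

import torch
import torch.nn as nn
import torch.quantization

class QuantizableModel(nn.Module):
    """Hypothetical model: one linear layer wrapped in quant/dequant stubs."""
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(512, 512)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

def serialized_size(model):
    """Bytes needed to save the model's state_dict with torch.save."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.tell()

torch.manual_seed(0)
float_model = QuantizableModel().eval()
q_model = copy.deepcopy(float_model)

q_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(q_model, inplace=True)
for _ in range(8):  # placeholder calibration batches
    q_model(torch.randn(16, 512))
torch.quantization.convert(q_model, inplace=True)

float_bytes = serialized_size(float_model)
quant_bytes = serialized_size(q_model)
print(f"float32: {float_bytes} bytes, int8: {quant_bytes} bytes")
print(f"ratio: {float_bytes / quant_bytes:.2f}x")
```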
Important Notes

Pruning can lower accuracy, more so at higher pruning amounts. Fine-tuning the pruned model often recovers part of the loss.
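A common way to win back accuracy is to keep training while the pruning mask is still attached, before calling prune.remove. The forward pass recomputes weight as weight_orig * weight_mask, so pruned entries stay at zero while the optimizer updates the rest. The sketch below checks this on a hypothetical layer, with random regression targets standing in for a real fine-tuning set.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(8, 4)  # hypothetical layer with 32 weights
prune.l1_unstructured(layer, name="weight", amount=0.5)
mask = layer.weight_mask.clone()
weights_before = layer.weight.detach().clone()

# Placeholder fine-tuning data; a real run would use the training set
x, target = torch.randn(32, 8), torch.randn(32, 4)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)  # updates weight_orig and bias

for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(layer(x), target)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    layer(x)  # forward pass recomputes layer.weight from weight_orig * weight_mask

weights_after = layer.weight.detach()
pruned_still_zero = bool((weights_after[mask == 0] == 0).all())
kept_weights_changed = not torch.equal(weights_after[mask == 1], weights_before[mask == 1])
print("pruned weights still zero:", pruned_still_zero)
print("kept weights updated:", kept_weights_changed)
```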

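Static quantization needs a calibration pass: run representative real inputs through the prepared model so its observers can record activation ranges before convert. Random data, as in the sample above, is only a placeholder.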
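Calibration works through observers. The sketch below drives one by hand. torch.ao.quantization.MinMaxObserver (the newer namespace for the torch.quantization tools used above) tracks the running minimum and maximum of the data it sees, then derives an 8-bit scale and zero point from them. Two caveats: the default 'fbgemm' qconfig uses a histogram-based observer for activations instead, and the scaled random batches stand in for real inputs.

```python
import torch
from torch.ao.quantization import MinMaxObserver

torch.manual_seed(0)
observer = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)

# Calibration: feed several batches; the observer only records statistics
for _ in range(10):
    observer(torch.randn(32, 10) * 3.0)

scale, zero_point = observer.calculate_qparams()
lo, hi = observer.min_val.item(), observer.max_val.item()
print(f"observed range: [{lo:.3f}, {hi:.3f}]")
print(f"scale = {scale.item():.5f}, zero_point = {zero_point.item()}")
```

For unsigned 8-bit affine quantization, the scale spreads the observed range over 255 steps, and the zero point is the integer code that represents 0.0.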

Always re-evaluate the optimized model on real test data, checking both accuracy and speed.
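One simple automated check is to run the original and the optimized model on the same held-out inputs and measure how far the outputs drift. The sketch below does this for pruning at a few amounts on a hypothetical layer. The helper name output_drift is made up for illustration. In practice you would also compare task accuracy and latency.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def output_drift(reference, candidate, inputs):
    """Relative L2 difference between two models' outputs on the same inputs."""
    with torch.no_grad():
        ref_out, cand_out = reference(inputs), candidate(inputs)
    return ((cand_out - ref_out).norm() / ref_out.norm()).item()

torch.manual_seed(0)
original = nn.Linear(64, 16)        # hypothetical trained layer
test_inputs = torch.randn(256, 64)  # stand-in for held-out data

drift = {}
for amount in (0.0, 0.2, 0.5, 0.8):
    pruned = copy.deepcopy(original)
    prune.l1_unstructured(pruned, name="weight", amount=amount)
    drift[amount] = output_drift(original, pruned, test_inputs)
    print(f"pruned {amount:.0%}: relative output drift {drift[amount]:.3f}")
```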

Summary

Model optimization makes models smaller and faster.

Pruning zeroes out less important weights.

Quantization reduces numeric precision to shrink models and speed them up.