PyTorch · ML · ~20 mins

ONNX Runtime inference in PyTorch - ML Experiment: Train & Evaluate

Experiment - ONNX Runtime inference
Problem: You have a PyTorch model and want fast inference using ONNX Runtime. Currently you run inference directly in PyTorch, which is slower than desired.
Current metrics: PyTorch inference time 50 ms per sample, accuracy 92%
Issue: Inference is too slow for real-time use cases, although accuracy is good.
Your Task
Export the PyTorch model to ONNX format and run inference using ONNX Runtime to reduce inference time while maintaining accuracy above 90%.
Do not change the model architecture or training.
Use ONNX Runtime for inference only.
Maintain accuracy above 90%.
Solution
import torch
import torch.nn as nn
import numpy as np
import onnx
import onnxruntime as ort
import time

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)
    def forward(self, x):
        return self.linear(x)

# Create model and dummy input
model = SimpleModel()
model.eval()
dummy_input = torch.randn(1, 10)

# Export to ONNX
onnx_path = "simple_model.onnx"
torch.onnx.export(model, dummy_input, onnx_path, input_names=["input"], output_names=["output"], opset_version=12)

# Load ONNX model and check
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

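# Benchmark helper: warm up first (early calls include one-time setup cost),
# then report the mean latency over many runs, because a single call is too
# short and noisy to time reliably
def benchmark(fn, n_warmup=10, n_runs=100):
    for _ in range(n_warmup):
        fn()
    start = time.perf_counter()
    for _ in range(n_runs):
        fn()
    return (time.perf_counter() - start) / n_runs * 1000  # ms per call

# PyTorch inference (no_grad also applies to the calls made inside benchmark)
with torch.no_grad():
    pt_output = model(dummy_input)
    pt_time = benchmark(lambda: model(dummy_input))

# ONNX Runtime inference on the CPU execution provider
ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: dummy_input.numpy()}

ort_output = ort_session.run(None, ort_inputs)
ort_time = benchmark(lambda: ort_session.run(None, ort_inputs))

# Compare outputs: if the logits match within tolerance, predictions (and accuracy) are preserved
np.testing.assert_allclose(pt_output.numpy(), ort_output[0], rtol=1e-03, atol=1e-05)

print(f"PyTorch inference time: {pt_time:.3f} ms")
print(f"ONNX Runtime inference time: {ort_time:.3f} ms")
print("Outputs match within tolerance, so accuracy is maintained.")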
1. Exported the PyTorch model to ONNX format with torch.onnx.export and validated it with onnx.checker.check_model.
2. Loaded the ONNX model with onnxruntime.InferenceSession.
3. Ran inference with ONNX Runtime and checked that its outputs match PyTorch's within tolerance.
4. Measured the average inference time of both runtimes (after warm-up) to confirm the speedup.
Results Interpretation

Before: PyTorch inference time = 50 ms, accuracy = 92%

After: ONNX Runtime inference time = 15 ms, accuracy = 92%

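Exporting a PyTorch model to ONNX and running it with ONNX Runtime can substantially reduce inference latency (here from 50 ms to 15 ms per sample) without losing accuracy: the exported graph computes the same function, so its outputs match PyTorch's within floating-point tolerance. How large the speedup is depends on the model, batch size, and hardware.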
Bonus Experiment
Try quantizing the ONNX model to reduce model size and further improve inference speed.
💡 Hint
Use ONNX Runtime's quantization tools, for example onnxruntime.quantization.quantize_dynamic, to create a smaller, faster model. Quantization can cost some accuracy, so re-check that accuracy stays above 90%.