PyTorch · ~20 mins

ONNX export in PyTorch - ML Experiment: Train & Evaluate

Experiment - ONNX export
Problem: You have a PyTorch model trained for digit classification. You want to export it to ONNX format to use it in other frameworks or deployment environments.
Current Metrics: Training accuracy: 98%, validation accuracy: 96%, model saved only in PyTorch format (.pt).
Issue: The model is saved only in PyTorch format, which limits interoperability with tools that support ONNX. You need to export the model correctly to ONNX format.
Your Task
Export the given PyTorch model to ONNX format, ensuring the exported model can be loaded and run with the same input shape.
Do not retrain the model.
Use the existing trained model weights.
Ensure the ONNX export includes input and output names.
Use a dummy input tensor with the correct shape for export.
Solution
PyTorch
import torch
import torch.nn as nn
import torch.onnx

# Define a simple model (for example purposes)
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28*28, 10)
    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Load the trained model weights (simulate loading)
model = SimpleNet()
# Normally you would load weights here, e.g., model.load_state_dict(torch.load('model.pt'))
model.eval()

# Create a dummy input: batch size 1, 1 channel, 28x28 image
dummy_input = torch.randn(1, 1, 28, 28)

# Export the model to ONNX format
onnx_path = "simple_net.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)

print(f"Model exported to {onnx_path}")
Added torch.onnx.export call to save the PyTorch model in ONNX format.
Used a dummy input tensor with correct shape (1, 1, 28, 28) for export.
Specified input and output node names for clarity.
Set opset_version to 11 for compatibility.
Enabled dynamic axes for batch size to allow variable input sizes.
Results Interpretation

Before: Model saved only in PyTorch format (.pt), limiting interoperability.

After: Model exported to ONNX format with input/output names and dynamic batch size support, enabling use in other frameworks.

Exporting a PyTorch model to ONNX format allows you to use the model in different environments and frameworks, improving flexibility and deployment options.
Bonus Experiment
Load the exported ONNX model using onnxruntime and run inference on a sample input to verify correctness.
💡 Hint
Use onnxruntime.InferenceSession to load the model and run session.run with the input dictionary.