How to Export a PyTorch Model to ONNX Format Easily
To export a PyTorch model to ONNX format, use torch.onnx.export() and pass it the model, a sample input tensor, and the output file path. The exporter uses the sample input to trace the model's computation and saves the resulting graph, together with the model's weights, as an .onnx file that other frameworks and runtimes can load.
Syntax
The basic syntax of exporting a PyTorch model to ONNX is:
```python
torch.onnx.export(model, args, f, export_params=True, opset_version=11, input_names=None, output_names=None)
```
- model: Your trained PyTorch model.
- args: A tuple or single tensor representing a sample input to the model.
- f: The file path where the ONNX model will be saved.
- export_params: Whether to export the trained parameters (usually True).
- opset_version: The ONNX operator set version to target. If you omit it, PyTorch uses a default that depends on the installed PyTorch version.
- input_names and output_names: Optional names for the model's inputs and outputs.
Example
This example shows how to export a simple PyTorch model to ONNX format. It creates a model, defines a dummy input, and saves the ONNX file.
```python
import torch
import torch.nn as nn


# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, x):
        return self.linear(x)


# Instantiate the model and set it to eval mode
model = SimpleModel()
model.eval()

# Create a dummy input tensor with the right shape
dummy_input = torch.randn(1, 3)

# Export the model to ONNX format
onnx_path = "simple_model.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=11,
    input_names=["input"],
    output_names=["output"],
)

print(f"Model exported to {onnx_path}")
```
Output
Model exported to simple_model.onnx
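Because the exporter feeds the dummy input through the model, a shape mismatch typically shows up as an error partway through export. A cheap safeguard is to run one forward pass first. The helper check_dummy_input below is hypothetical, a minimal sketch rather than part of PyTorch's API, and it uses a bare nn.Linear(3, 2) as a stand-in for SimpleModel.

```python
import torch
import torch.nn as nn


def check_dummy_input(model: nn.Module, dummy_input: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: run one forward pass so shape errors surface before export."""
    model.eval()
    with torch.no_grad():
        return model(dummy_input)


model = nn.Linear(3, 2)

# Matching shape: the forward pass succeeds and reveals the output shape
out = check_dummy_input(model, torch.randn(1, 3))
print(tuple(out.shape))  # (1, 2)

# Mismatched shape: nn.Linear(3, 2) expects 3 features in the last dimension
try:
    check_dummy_input(model, torch.randn(1, 4))
except RuntimeError as err:
    print("Shape mismatch:", err)
```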
Common Pitfalls
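- Not setting the model to eval mode: Always call model.eval() before export so that dropout is disabled and batch norm uses its stored running statistics instead of per-batch statistics. Exporting in training mode can cause inconsistent behavior in the exported model.
- Incorrect dummy input shape: The dummy input must have a shape (and dtype) the model accepts, or the export fails. Unless you mark dimensions as dynamic (for example with the dynamic_axes argument), the exported graph is also fixed to the dummy input's shape.
- Missing opset version: Set a recent opset_version (11 or higher) so that newer operators are supported.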
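```python
import torch
import torch.nn as nn

model = nn.Linear(3, 2)
dummy_input = torch.randn(1, 3)

# Wrong: the model is still in training mode (no model.eval() call).
# Exporting now may succeed but can cause issues later.
# torch.onnx.export(model, dummy_input, "linear.onnx")

# Correct: switch to evaluation mode first, then export as usual
# ("linear.onnx" is an illustrative file name)
model.eval()
torch.onnx.export(model, dummy_input, "linear.onnx", opset_version=11)
```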
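The snippet above uses a bare nn.Linear, which has no dropout or batch norm, so switching modes does not change its output. The difference becomes visible with a model that contains dropout. The sketch below assumes a small hypothetical nn.Sequential model and compares two forward passes on the same input in each mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model with dropout, whose behavior depends on the mode
model = nn.Sequential(nn.Linear(3, 64), nn.Dropout(p=0.5), nn.Linear(64, 2))
x = torch.randn(1, 3)

# Training mode: dropout zeroes a random subset of activations on every call
model.train()
train_differs = not torch.equal(model(x), model(x))

# Evaluation mode: dropout is a no-op, so repeated calls give identical outputs
model.eval()
with torch.no_grad():
    eval_same = torch.equal(model(x), model(x))

print("train-mode outputs differ:", train_differs)
print("eval-mode outputs match:", eval_same)
```

With 64 hidden units, two independent dropout masks are almost certain to differ, which is why the training-mode outputs disagree.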
Quick Reference
Remember these tips when exporting PyTorch models to ONNX:
- Always call model.eval() before export.
- Provide a dummy input tensor that matches the model's input shape.
- Set opset_version to 11 or higher for compatibility.
- Name inputs and outputs for clarity.
- Check the exported ONNX file with tools like onnxruntime or Netron.
Key Takeaways
Use torch.onnx.export() with model, dummy input, and output file path to export PyTorch models to ONNX.
Always set your model to evaluation mode with model.eval() before exporting.
Ensure the dummy input tensor matches the model's expected input shape exactly.
Use a recent opset_version (11 or higher) for better operator support.
Name inputs and outputs to make the ONNX model easier to understand and debug.