
ONNX export in PyTorch

Introduction

ONNX export saves your PyTorch model in the Open Neural Network Exchange (ONNX) format, an open standard that many other tools and platforms can consume. This makes it easy to share your model and run it outside PyTorch.

Common reasons to export to ONNX:

You want to run your PyTorch model on a system that does not support PyTorch.
You want to optimize your model for faster inference with ONNX Runtime or other accelerators.
You want to deploy your model to mobile or web platforms that support ONNX.
You want to convert your model to another framework, such as TensorFlow or Caffe2.
You want to share your model with people who use different machine learning tools.
Syntax
PyTorch
torch.onnx.export(model, args, f, export_params=True, opset_version=None, do_constant_folding=True, input_names=None, output_names=None, dynamic_axes=None)

model: Your trained PyTorch model (an nn.Module).

args: Example input tensor(s) used to trace the model.

f: File path where the ONNX model is saved.

export_params: If True (the default), store the trained weights inside the exported file.

opset_version: ONNX operator-set version to target; choose one your runtime supports.

do_constant_folding: If True (the default), pre-compute constant expressions to simplify the graph.

input_names / output_names: Readable names for the graph's input and output nodes.

dynamic_axes: Mark dimensions (such as batch size) as variable instead of fixed.

Examples
Export model with default settings using a dummy input tensor.
PyTorch
torch.onnx.export(model, dummy_input, "model.onnx")
Specify names for input and output nodes for clarity.
PyTorch
torch.onnx.export(model, dummy_input, "model.onnx", input_names=["input"], output_names=["output"])
Export with opset version 11 and dynamic batch size support.
PyTorch
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=11, dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}})
Sample Model

This code defines a small linear model, creates a dummy input, and exports the model to ONNX format with named inputs and outputs. It also supports dynamic batch size.

PyTorch
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
model.eval()  # Set to evaluation mode

dummy_input = torch.randn(1, 3)  # Example input

# Export the model to ONNX format
torch.onnx.export(
    model,
    dummy_input,
    "simple_model.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=11,
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)

print("ONNX model exported successfully to simple_model.onnx")
Output
ONNX model exported successfully to simple_model.onnx
Important Notes

Make sure your model is in evaluation mode (model.eval()) before exporting to avoid training-only behaviors like dropout.

Use a dummy input with the shape (and dtype) your model expects; the exporter traces the model with this input.

Choosing the right opset_version ensures compatibility with the ONNX runtime or other tools you plan to use.

Summary

ONNX export saves PyTorch models in a universal format for sharing and deployment.

Use dummy inputs and set model.eval() before exporting.

You can customize input/output names and support dynamic shapes during export.