torch.onnx
Created On: Jun 10, 2025 | Last Updated On: Sep 10, 2025
Overview
Open Neural Network eXchange (ONNX) is an open standard format for representing machine learning models. The torch.onnx module captures the computation graph from a native PyTorch torch.nn.Module model and converts it into an ONNX graph.
The exported model can be consumed by any of the many runtimes that support ONNX, including Microsoft’s ONNX Runtime.
The following example shows how to export a simple model.
```python
import torch


class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 128, 5)

    def forward(self, x):
        return torch.relu(self.conv1(x))


input_tensor = torch.rand((1, 1, 128, 128), dtype=torch.float32)

model = MyModel()

torch.onnx.export(
    model,  # model to export
    (input_tensor,),  # inputs of the model
    "my_model.onnx",  # filename of the ONNX model
    input_names=["input"],  # Rename inputs for the ONNX model
    dynamo=True,  # True or False to select the exporter to use
)
```
torch.export-based ONNX Exporter
The torch.export-based ONNX exporter is the newest exporter for PyTorch 2.6 and newer.
The torch.export engine is used to produce a traced graph that represents only the Tensor computation of the function, in an Ahead-of-Time (AOT) fashion. The resulting traced graph (1) contains normalized operators in the functional ATen operator set (as well as any user-specified custom operators), (2) has eliminated all Python control flow and data structures (with certain exceptions), and (3) records the set of shape constraints needed to show that this normalization and control-flow elimination is sound for future inputs. The graph is then translated into an ONNX graph.
Frequently Asked Questions
Q: I have exported my LLM model, but its input size seems to be fixed?
The tracer records the shapes of the example inputs. If the model should accept inputs with dynamic shapes, set dynamic_shapes when calling torch.onnx.export().
Q: How do I export models containing loops?
See torch.cond.
Contributing / Developing
The ONNX exporter is a community project and we welcome contributions. We follow the PyTorch guidelines for contributions, but you might also be interested in reading our development wiki.
torch.onnx APIs
Functions
- torch.onnx.export(model, args=(), f=None, *, kwargs=None, verbose=None, input_names=None, output_names=None, opset_version=None, dynamo=True, external_data=True, dynamic_shapes=None, custom_translation_table=None, report=False, optimize=True, verify=False, profile=False, dump_exported_program=False, artifacts_dir='.', fallback=False, export_params=True, keep_initializers_as_inputs=False, dynamic_axes=None, training=<TrainingMode.EVAL: 0>, operator_export_type=<OperatorExportTypes.ONNX: 0>, do_constant_folding=True, custom_opsets=None, export_modules_as_functions=False, autograd_inlining=True)
Exports a model into ONNX format.
Setting dynamo=True enables the new ONNX export logic, which is based on torch.export.ExportedProgram and a more modern set of translation logic. This is the recommended and default way to export models to ONNX.
When dynamo=True, the exporter tries the following strategies, in order, to get an ExportedProgram for conversion to ONNX:
1. If the model is already an ExportedProgram, it is used as-is.
2. Use torch.export.export() with strict=False.
3. Use torch.export.export() with strict=True.
- Parameters:
model (torch.nn.Module | torch.export.ExportedProgram | torch.jit.ScriptModule | torch.jit.ScriptFunction) – The model to be exported.
args (tuple[Any, ...]) – Example positional inputs. Any non-Tensor arguments will be hard-coded into the exported model; any Tensor arguments will become inputs of the exported model, in the order they occur in the tuple.
f (str | os.PathLike | None) – Path to the output ONNX model file, e.g. “model.onnx”. This argument is kept for backward compatibility. It is recommended to leave it unspecified (None) and use the returned torch.onnx.ONNXProgram to serialize the model to a file instead.
kwargs (dict[str, Any] | None) – Optional example keyword inputs.
verbose (bool | None) – Whether to enable verbose logging.
input_names (Sequence[str] | None) – Names to assign to the input nodes of the graph, in order.
output_names (Sequence[str] | None) – Names to assign to the output nodes of the graph, in order.
opset_version (int | None) – The version of the default (ai.onnx) opset to target. You should set opset_version according to the opset versions supported by the runtime backend or compiler you want to run the exported model with. Leave it as default (None) to use the recommended version, or refer to the ONNX operators documentation for more information.
dynamo (bool) – Whether to export the model with a torch.export ExportedProgram instead of TorchScript.
external_data (bool) – Whether to save the model weights as an external data file. This is required for models with large weights that exceed the ONNX file size limit (2 GB). When False, the weights are saved in the ONNX file with the model architecture.
dynamic_shapes (dict[str, Any] | tuple[Any, ...] | list[Any] | None) – A dictionary or a tuple of dynamic shapes for the model inputs. Refer to torch.export.export() for more details. This is only used (and preferred) when dynamo is True. Note that dynamic_shapes is designed to be used when the model is exported with dynamo=True, while dynamic_axes is used when dynamo=False.
custom_translation_table (dict[Callable, Callable | Sequence[Callable]] | None) – A dictionary of custom decompositions for operators in the model. The dictionary should have the callable target in the fx Node as the key (e.g. torch.ops.aten.stft.default), and the value should be a function that builds that graph using ONNX Script. This option is only valid when dynamo is True.
report (bool) – Whether to generate a markdown report for the export process. This option is only valid when dynamo is True.
optimize (bool) – Whether to optimize the exported model. This option is only valid when dynamo is True. Default is True.
verify (bool) – Whether to verify the exported model using ONNX Runtime. This option is only valid when dynamo is True.
profile (bool) – Whether to profile the export process. This option is only valid when dynamo is True.
dump_exported_program (bool) – Whether to dump the torch.export.ExportedProgram to a file. This is useful for debugging the exporter. This option is only valid when dynamo is True.
artifacts_dir (str | os.PathLike) – The directory to save debugging artifacts, such as the report and the serialized exported program. This option is only valid when dynamo is True.
fallback (bool) – Whether to fall back to the TorchScript exporter if the dynamo exporter fails. This option is only valid when dynamo is True. When fallback is enabled, it is recommended to set dynamic_axes even when dynamic_shapes is provided.
export_params (bool) –
When f is specified: if False, parameters (weights) will not be exported.
You can also leave it unspecified and use the returned torch.onnx.ONNXProgram to control how initializers are treated when serializing the model.
keep_initializers_as_inputs (bool) –
When f is specified: if True, all the initializers (typically corresponding to model weights) in the exported graph will also be added as inputs to the graph. If False, initializers are not added as inputs to the graph, and only the user inputs are added as inputs.
Set this to True if you intend to supply model weights at runtime. Set it to False if the weights are static, to allow for better optimizations (e.g. constant folding) by backends/runtimes.
You can also leave it unspecified and use the returned torch.onnx.ONNXProgram to control how initializers are treated when serializing the model.
dynamic_axes (Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None) –
Prefer specifying dynamic_shapes when dynamo=True and when fallback is not enabled.
By default the exported model will have the shapes of all input and output tensors set to exactly match those given in args. To specify axes of tensors as dynamic (i.e. known only at run-time), set dynamic_axes to a dict with schema:
- KEY (str): an input or output name. Each name must also be provided in input_names or output_names.
- VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a list, each element is an axis index.
For example:
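```python
import torch


class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1)


torch.onnx.export(
    SumModule(),
    (torch.ones(2, 2),),
    "onnx.pb",
    input_names=["x"],
    output_names=["sum"],
    dynamo=False,  # dynamic_axes applies to the TorchScript exporter
)
```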
Produces:
```
input {
  name: "x"
  ...
  shape {
    dim {
      dim_value: 2  # axis 0
    }
    dim {
      dim_value: 2  # axis 1
...
output {
  name: "sum"
  ...
  shape {
    dim {
      dim_value: 2  # axis 0
...
```
While:
```python
torch.onnx.export(
    SumModule(),
    (torch.ones(2, 2),),
    "onnx.pb",
    input_names=["x"],
    output_names=["sum"],
    dynamo=False,  # dynamic_axes applies to the TorchScript exporter
    dynamic_axes={
        # dict value: manually named axes
        "x": {0: "my_custom_axis_name"},
        # list value: automatic names
        "sum": [0],
    },
)
```
Produces:
```
input {
  name: "x"
  ...
  shape {
    dim {
      dim_param: "my_custom_axis_name"  # axis 0
    }
    dim {
      dim_value: 2  # axis 1
...
output {
  name: "sum"
  ...
  shape {
    dim {
      dim_param: "sum_dynamic_axes_1"  # axis 0
...
```
training (_C_onnx.TrainingMode) – Deprecated option. Instead, set the training mode of the model before exporting.
operator_export_type (_C_onnx.OperatorExportTypes) – Deprecated option. Only ONNX is supported.
do_constant_folding (bool) – Deprecated option.
export_modules_as_functions (bool |Collection[type[torch.nn.Module]]) – Deprecated option.
autograd_inlining (bool) – Deprecated option.
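The automatic naming used in the dynamic_axes examples can be written as a small pure-Python function. normalize_dynamic_axes is hypothetical and is not a torch.onnx API. It assumes that list entries are named <name>_dynamic_axes_<k>, where k counts list positions starting from 1. This rule matches the sum_dynamic_axes_1 name in the example output.

```python
from collections.abc import Mapping, Sequence
from typing import Union

AxesSpec = Union[Mapping[int, str], Sequence[int]]


def normalize_dynamic_axes(dynamic_axes: Mapping[str, AxesSpec]) -> dict:
    """Hypothetical helper: turn every dynamic_axes value into {axis_index: axis_name}."""
    normalized = {}
    for name, axes in dynamic_axes.items():
        if isinstance(axes, Mapping):
            # dict value: the user already chose the axis names
            normalized[name] = dict(axes)
            continue
        # list value: generate "<name>_dynamic_axes_<k>", with k starting at 1
        named = {}
        for position, axis in enumerate(axes):
            if not isinstance(axis, int):
                raise ValueError(f"axis index for {name!r} must be an int, got {axis!r}")
            named.setdefault(axis, f"{name}_dynamic_axes_{position + 1}")  # ignore duplicates
        normalized[name] = named
    return normalized


print(normalize_dynamic_axes({"x": {0: "my_custom_axis_name"}, "sum": [0]}))
# {'x': {0: 'my_custom_axis_name'}, 'sum': {0: 'sum_dynamic_axes_1'}}
```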
- Returns:
torch.onnx.ONNXProgram if dynamo is True, otherwise None.
- Return type:
ONNXProgram | None
Changed in version 2.6: training is now deprecated. Instead, set the training mode of the model before exporting. operator_export_type is now deprecated. Only ONNX is supported. do_constant_folding is now deprecated. It is always enabled. export_modules_as_functions is now deprecated. autograd_inlining is now deprecated.
Changed in version 2.7: optimize is now True by default.
Changed in version 2.9: dynamo is now True by default.
Classes
- class torch.onnx.ONNXProgram(model, exported_program)
A class to represent an ONNX program that is callable with torch tensors.
- Variables:
model – The ONNX model as an ONNX IR model object.
exported_program – The exported program that produced the ONNX model.
- class torch.onnx.OnnxExporterError
Errors raised by the ONNX exporter. This is the base class for all exporter errors.
Deprecated APIs
Deprecated since version 2.6: These functions are deprecated and will be removed in a future version.
- torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)
Registers a symbolic function for a custom operator.
When registering a symbolic function for custom/contrib ops, it is highly recommended to add shape inference for that operator via the setType API; otherwise, the exported graph may have incorrect shape inference in some edge cases. An example of setType is test_aten_embedding_2 in test_operators.py.
See “Custom Operators” in the module documentation for example usage.
- Parameters:
symbolic_name (str) – The name of the custom operator in “<domain>::<op>” format.
symbolic_fn (Callable) – A function that takes in the ONNX graph and the input arguments to the current operator, and returns new operator nodes to add to the graph.
opset_version (int) – The ONNX opset version in which to register.
- torch.onnx.unregister_custom_op_symbolic(symbolic_name, opset_version)
Unregisters symbolic_name.
See “Custom Operators” in the module documentation for example usage.
- torch.onnx.select_model_mode_for_export(model, mode)
A context manager that temporarily sets the training mode of model to mode and resets it when the with-block exits.
Deprecated since version 2.7: Please set the training mode before exporting the model.
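Some code may still need a temporary mode switch. In that case, the documented behavior can be reproduced with torch.nn.Module methods. temporary_training_mode is a hypothetical sketch and is not a torch.onnx API. For exporting, calling model.eval() before torch.onnx.export() follows the deprecation advice above.

```python
import contextlib

import torch


@contextlib.contextmanager
def temporary_training_mode(model: torch.nn.Module, mode: bool):
    """Hypothetical helper: apply model.train(mode), then restore every submodule's flag."""
    # Record each submodule's flag so that mixed train/eval states are restored exactly.
    previous = {module: module.training for module in model.modules()}
    model.train(mode)
    try:
        yield model
    finally:
        for module, was_training in previous.items():
            module.training = was_training


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5))
model.train()

with temporary_training_mode(model, False):
    inside = [m.training for m in model.modules()]  # Sequential, Linear, Dropout

after = [m.training for m in model.modules()]
print(inside, after)
```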