
ONNX defines a long list of operators used in machine learning models. They are used to implement functions. This step is usually taken care of by converting libraries: sklearn-onnx for scikit-learn, torch.onnx for pytorch, tensorflow-onnx for tensorflow. Both torch.onnx and tensorflow-onnx convert any function expressed with the functions available in those packages, and that works because there is usually no need to mix packages. But on some occasions, there is a need to directly write functions with the onnx syntax. scikit-learn is implemented with numpy and there is no converter from numpy to onnx. Sometimes, an existing onnx model must be extended or models coming from different packages must be merged. Sometimes, the original implementations are just not available, only onnx is. Let's see what it looks like with a very simple example.
For example, the well-known Euclidean distance can be expressed with numpy as follows:
import numpy as np


def euclidean(X: np.array, Y: np.array) -> float:
    return ((X - Y) ** 2).sum()
The mathematical function must first be translated with ONNX Operators or primitives. It is usually easy because the primitives are very close to what numpy defines. It could be expressed as follows (the syntax is just for illustration).
import onnx


def euclidean(X: onnx.TensorProto[FLOAT], Y: onnx.TensorProto[FLOAT]) -> onnx.FLOAT:
    dxy = onnx.Sub(X, Y)
    sxy = onnx.Pow(dxy, 2)
    d = onnx.ReduceSum(sxy)
    return d
This example is short but does not work as it is. The inner API defined in onnx.helper is quite verbose and the true implementation would be the following.
<<<
import onnx
import onnx.helper as oh


def make_euclidean(
    input_names: tuple[str] = ("X", "Y"),
    output_name: str = "Z",
    elem_type: int = onnx.TensorProto.FLOAT,
    opset: int | None = None,
) -> onnx.ModelProto:
    if opset is None:
        opset = onnx.defs.onnx_opset_version()

    X = oh.make_tensor_value_info(input_names[0], elem_type, None)
    Y = oh.make_tensor_value_info(input_names[1], elem_type, None)
    Z = oh.make_tensor_value_info(output_name, elem_type, None)
    two = oh.make_tensor("two", onnx.TensorProto.INT64, [1], [2])
    n1 = oh.make_node("Sub", ["X", "Y"], ["dxy"])
    n2 = oh.make_node("Pow", ["dxy", "two"], ["dxy2"])
    n3 = oh.make_node("ReduceSum", ["dxy2"], [output_name])
    graph = oh.make_graph([n1, n2, n3], "euclidian", [X, Y], [Z], [two])
    model = oh.make_model(
        graph,
        opset_imports=[oh.make_opsetid("", opset)],
        ir_version=9,
    )
    return model


model = make_euclidean()
print(model)
>>>
ir_version: 9
graph {
  node {
    input: "X"
    input: "Y"
    output: "dxy"
    op_type: "Sub"
  }
  node {
    input: "dxy"
    input: "two"
    output: "dxy2"
    op_type: "Pow"
  }
  node {
    input: "dxy2"
    output: "Z"
    op_type: "ReduceSum"
  }
  name: "euclidian"
  initializer {
    dims: 1
    data_type: 7
    int64_data: 2
    name: "two"
  }
  input {
    name: "X"
    type {
      tensor_type {
        elem_type: 1
      }
    }
  }
  input {
    name: "Y"
    type {
      tensor_type {
        elem_type: 1
      }
    }
  }
  output {
    name: "Z"
    type {
      tensor_type {
        elem_type: 1
      }
    }
  }
}
opset_import {
  domain: ""
  version: 23
}
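Before comparing implementations, the generated ModelProto can also be validated with the checker shipped with onnx. A quick sketch, reusing the make_euclidean function defined above:

import onnx

# Reuse the make_euclidean function defined above.
model = make_euclidean()

# check_model raises an exception when the graph is malformed
# (missing inputs, wrong attribute types, unknown operators...).
onnx.checker.check_model(model)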
Since it is a second implementation of an existing function, it is necessary to check the output is the same.
<<<
import numpy as np
from numpy.testing import assert_allclose
from onnx.reference import ReferenceEvaluator
from onnx_array_api.ext_test_case import ExtTestCase

# This is the same function.
from onnx_array_api.validation.docs import make_euclidean


def test_make_euclidean():
    model = make_euclidean()

    ref = ReferenceEvaluator(model)

    X = np.random.rand(3, 4).astype(np.float32)
    Y = np.random.rand(3, 4).astype(np.float32)
    expected = ((X - Y) ** 2).sum(keepdims=1)
    got = ref.run(None, {"X": X, "Y": Y})[0]
    assert_allclose(expected, got, atol=1e-6)


test_make_euclidean()
>>>
But the reference implementation in onnx is not the runtime used to deploy the model. A second unit test must be added to check this one as well.
<<<
import numpy as np
from numpy.testing import assert_allclose
from onnx_array_api.ext_test_case import ExtTestCase

# This is the same function.
from onnx_array_api.validation.docs import make_euclidean


def test_make_euclidean_ort():
    from onnxruntime import InferenceSession

    model = make_euclidean()

    ref = InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )

    X = np.random.rand(3, 4).astype(np.float32)
    Y = np.random.rand(3, 4).astype(np.float32)
    expected = ((X - Y) ** 2).sum(keepdims=1)
    got = ref.run(None, {"X": X, "Y": Y})[0]
    assert_allclose(expected, got, atol=1e-6)


try:
    test_make_euclidean_ort()
except Exception as e:
    print(e)
>>>
[ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Failed to load model with error: /home/xadupre/github/onnxruntime/onnxruntime/core/graph/model_load_utils.h:46 void onnxruntime::model_load_utils::ValidateOpsetForDomain(const std::unordered_map<std::__cxx11::basic_string<char>, int>&, const onnxruntime::logging::Logger&, bool, const string&, int) ONNX Runtime only *guarantees* support for models stamped with official released onnx opset versions. Opset 23 is under development and support for this is limited. The operator schemas and or other functionality may change before next ONNX release and in this case ONNX Runtime will not guarantee backward compatibility. Current official support for domain ai.onnx is till opset 21.
The list of operators is constantly evolving: onnx is versioned. The function may fail because the model says it is using a version a runtime does not support. Let's change it.
<<<
import numpy as np
from numpy.testing import assert_allclose
from onnx_array_api.ext_test_case import ExtTestCase

# This is the same function.
from onnx_array_api.validation.docs import make_euclidean


def test_make_euclidean_ort():
    from onnxruntime import InferenceSession

    # opset=18: it uses the opset version 18, this number
    # is incremented at every minor release.
    model = make_euclidean(opset=18)

    ref = InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )

    X = np.random.rand(3, 4).astype(np.float32)
    Y = np.random.rand(3, 4).astype(np.float32)
    expected = ((X - Y) ** 2).sum(keepdims=1)
    got = ref.run(None, {"X": X, "Y": Y})[0]
    assert_allclose(expected, got, atol=1e-6)


test_make_euclidean_ort()
>>>
But the runtime must support many versions and the unit tests may look like the following:
<<<
import numpy as np
from numpy.testing import assert_allclose
import onnx.defs
from onnx_array_api.ext_test_case import ExtTestCase

# This is the same function.
from onnx_array_api.validation.docs import make_euclidean


def test_make_euclidean_ort():
    from onnxruntime import InferenceSession

    X = np.random.rand(3, 4).astype(np.float32)
    Y = np.random.rand(3, 4).astype(np.float32)
    expected = ((X - Y) ** 2).sum(keepdims=1)

    for opset in range(6, onnx.defs.onnx_opset_version() - 1):
        model = make_euclidean(opset=opset)

        try:
            ref = InferenceSession(
                model.SerializeToString(), providers=["CPUExecutionProvider"]
            )
            got = ref.run(None, {"X": X, "Y": Y})[0]
        except Exception as e:
            print(f"fail opset={opset}", e)
            if opset < 18:
                continue
            raise e
        assert_allclose(expected, got, atol=1e-6)


test_make_euclidean_ort()
>>>
fail opset=6 [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (two) of operator (Pow) in node () is invalid.
fail opset=7 [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (two) of operator (Pow) in node () is invalid.
fail opset=8 [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (two) of operator (Pow) in node () is invalid.
fail opset=9 [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (two) of operator (Pow) in node () is invalid.
fail opset=10 [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (two) of operator (Pow) in node () is invalid.
fail opset=11 [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (two) of operator (Pow) in node () is invalid.
These failures are expected: before opset 12, the operator Pow required its exponent to have the same type as its base, so the int64 constant is rejected. This work is quite long even for a simple function. For a longer one, due to the verbosity of the inner API, it is quite difficult to write a correct implementation on the first try. The unit tests cannot be avoided. The inner API is usually enough when the translation from python to onnx does not happen often. When it does, almost every library implements its own simplified way to create onnx graphs, and because creating such an API is not difficult, the decision was often made to create a new one rather than to reuse an existing one.
Many options already exist to write custom onnx graphs. Their development is usually driven by what they are used for. Each of them may not fully support all your needs, and it is not always easy to understand the error messages they produce when something goes wrong. It is better to understand one's own needs before choosing one. Here are some of the questions which may need to be answered.
ability to easily write loops and tests (control flow)
ability to debug (eager mode)
ability to use the same function to produce different implementations depending on the opset version
ability to interact with other frameworks
ability to merge existing onnx graphs
ability to describe an existing graph with this API
ability to easily define constants
ability to handle multiple domains
ability to support local functions
easy-to-understand error messages
is it actively maintained?
pytorch offers the possibility to convert any function implemented with pytorch functions into onnx with torch.onnx. A couple of examples follow.
import torch
import torch.nn


class MyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x, bias=None):
        out = self.linear(x)
        out = out + bias
        return out


model = MyModel()
kwargs = {"bias": 3.0}
inputs = (torch.randn(2, 2, 2),)
export_output = torch.onnx.dynamo_export(model, *inputs, **kwargs)
export_output.save("my_simple_model.onnx")
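Once saved, the exported file behaves like any other onnx model and can be loaded with a runtime such as onnxruntime. A minimal sketch, assuming the my_simple_model.onnx file created above:

from onnxruntime import InferenceSession

sess = InferenceSession("my_simple_model.onnx", providers=["CPUExecutionProvider"])
# The exporter generates the input names; query the session to discover them.
print([i.name for i in sess.get_inputs()])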
from typing import Dict, Tuple
import torch
import torch.onnx


def func_with_nested_input_structure(
    x_dict: Dict[str, torch.Tensor],
    y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
):
    if "a" in x_dict:
        x = x_dict["a"]
    elif "b" in x_dict:
        x = x_dict["b"]
    else:
        x = torch.randn(3)

    y1, (y2, y3) = y_tuple

    return x + y1 + y2 + y3


x_dict = {"a": torch.tensor(1.0)}
y_tuple = (torch.tensor(2.0), (torch.tensor(3.0), torch.tensor(4.0)))
export_output = torch.onnx.dynamo_export(func_with_nested_input_structure, x_dict, y_tuple)

print(export_output.adapt_torch_inputs_to_onnx(x_dict, y_tuple))
onnxscript is used in Torch Export to ONNX. It converts python code to onnx code by analyzing the python code (through ast). The package makes it very easy to use loops and tests in onnx. It is very close to the onnx syntax. It is not easy to support multiple implementations depending on the opset version required by the user.
Example taken from the documentation:
import onnx

# We use ONNX opset 15 to define the function below.
from onnxscript import FLOAT
from onnxscript import opset15 as op
from onnxscript import script


# We use the script decorator to indicate that
# this is meant to be translated to ONNX.
@script()
def onnx_hardmax(X, axis: int):
    """Hardmax is similar to ArgMax, with the result being encoded OneHot style."""

    # The type annotation on X indicates that it is a float tensor of
    # unknown rank. The type annotation on axis indicates that it will
    # be treated as an int attribute in ONNX.
    #
    # Invoke ONNX opset 15 op ArgMax.
    # Use unnamed arguments for ONNX input parameters, and named
    # arguments for ONNX attribute parameters.
    argmax = op.ArgMax(X, axis=axis, keepdims=False)
    xshape = op.Shape(X, start=axis)
    # use the Constant operator to create constant tensors
    zero = op.Constant(value_ints=[0])
    depth = op.GatherElements(xshape, zero)
    empty_shape = op.Constant(value_ints=[0])
    depth = op.Reshape(depth, empty_shape)
    values = op.Constant(value_ints=[0, 1])
    cast_values = op.CastLike(values, X)
    return op.OneHot(argmax, depth, cast_values, axis=axis)


# We use the script decorator to indicate that
# this is meant to be translated to ONNX.
@script()
def sample_model(X: FLOAT[64, 128], Wt: FLOAT[128, 10], Bias: FLOAT[10]) -> FLOAT[64, 10]:
    matmul = op.MatMul(X, Wt) + Bias
    return onnx_hardmax(matmul, axis=1)


# onnx_model is an in-memory ModelProto
onnx_model = sample_model.to_model_proto()

# Save the ONNX model at a given path
onnx.save(onnx_model, "sample_model.onnx")

# Check the model
try:
    onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
    print(f"The model is invalid: {e}")
else:
    print("The model is valid!")
An Eager mode is available to debug what the code does.
import numpy as np

# Hardmax refers to a function decorated with @script();
# calling it directly on numpy arrays executes it eagerly.
v = np.array([[0, 1], [2, 3]], dtype=np.float32)
result = Hardmax(v)
The syntax of spox is similar but it does not use ast. Therefore, loops and tests are expressed in a very different way. The tricky part with it is to handle the local context. A variable created in the main graph is known by any of its subgraphs.
Example taken from the documentation:
import onnx

from spox import argument, build, Tensor, Var
# Import operators from the ai.onnx domain at version 17
from spox.opset.ai.onnx import v17 as op


def geometric_mean(x: Var, y: Var) -> Var:
    # use the standard Sqrt and Mul
    return op.sqrt(op.mul(x, y))


# Create typed model inputs. Each tensor is of rank 1
# and has the runtime-determined length 'N'.
a = argument(Tensor(float, ('N',)))
b = argument(Tensor(float, ('N',)))

# Perform operations on `Var`s
c = geometric_mean(a, b)

# Build an `onnx.ModelProto` for the given inputs and outputs.
model: onnx.ModelProto = build(inputs={'a': a, 'b': b}, outputs={'c': c})
The function can be tested with a mechanism called value propagation.
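A minimal sketch of that mechanism, assuming spox's const helper: when every input is a known constant, spox computes the result while the graph is being built, and the propagated value is visible on the resulting Var.

import numpy as np
from spox.opset.ai.onnx import v17 as op

# Constants instead of runtime arguments: the values are known at build time.
x = op.const(np.array([1.0, 4.0]))
y = op.const(np.array([16.0, 4.0]))

# sqrt(x * y) is computed during graph construction; the propagated
# value ([4., 4.]) shows up in the repr of the resulting Var.
c = op.sqrt(op.mul(x, y))
print(c)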
sklearn-onnx also implements its own API to add custom graphs. It was designed to shorten the time spent in reimplementing scikit-learn code into onnx code. It can be used to implement a new converter mapped to a custom model as described in this example: Implement a new converter. But it can also be used to build standalone models.
<<<
import numpy as np
import onnx
import onnx.helper as oh
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot


def make_euclidean_skl2onnx(
    input_names: tuple[str] = ("X", "Y"),
    output_name: str = "Z",
    elem_type: int = onnx.TensorProto.FLOAT,
    opset: int | None = None,
) -> onnx.ModelProto:
    if opset is None:
        opset = onnx.defs.onnx_opset_version()

    from skl2onnx.algebra.onnx_ops import OnnxSub, OnnxPow, OnnxReduceSum

    dxy = OnnxSub(input_names[0], input_names[1], op_version=opset)
    dxy2 = OnnxPow(dxy, np.array([2], dtype=np.int64), op_version=opset)
    final = OnnxReduceSum(dxy2, op_version=opset, output_names=[output_name])

    np_type = oh.tensor_dtype_to_np_dtype(elem_type)
    dummy = np.empty([1], np_type)
    return final.to_onnx({"X": dummy, "Y": dummy})


model = make_euclidean_skl2onnx()
print(onnx_simple_text_plot(model))
>>>
opset: domain='' version=15
input: name='X' type=dtype('float32') shape=['']
input: name='Y' type=dtype('float32') shape=['']
init: name='Po_Powcst' type=int64 shape=(1,) -- array([2])
Sub(X, Y) -> Su_C0
Pow(Su_C0, Po_Powcst) -> Po_Z0
ReduceSum(Po_Z0) -> Z
output: name='Z' type=dtype('float32') shape=[1]
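The resulting model can be checked against the numpy implementation the same way as the handwritten one. A short sketch following the earlier test pattern:

import numpy as np
from numpy.testing import assert_allclose
from onnx.reference import ReferenceEvaluator

model = make_euclidean_skl2onnx()
ref = ReferenceEvaluator(model)
# The model was built with rank-1 dummies, so rank-1 inputs are used here.
X = np.random.rand(4).astype(np.float32)
Y = np.random.rand(4).astype(np.float32)
got = ref.run(None, {"X": X, "Y": Y})[0]
assert_allclose(((X - Y) ** 2).sum(keepdims=1), got, atol=1e-6)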
onnxblocks was introduced in onnxruntime to define custom losses in order to train a model with onnxruntime-training. It is mostly used for this usage. The syntax is similar to pytorch.
import onnx
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training import artifacts


# Define a custom loss block that takes in two inputs
# and performs a weighted average of the losses from these
# two inputs.
class WeightedAverageLoss(onnxblock.Block):
    def __init__(self):
        self._loss1 = onnxblock.loss.MSELoss()
        self._loss2 = onnxblock.loss.MSELoss()
        self._w1 = onnxblock.blocks.Constant(0.4)
        self._w2 = onnxblock.blocks.Constant(0.6)
        self._add = onnxblock.blocks.Add()
        self._mul = onnxblock.blocks.Mul()

    def build(self, loss_input_name1, loss_input_name2):
        # The build method defines how the block should be stacked on top of
        # loss_input_name1 and loss_input_name2.
        # Returns the weighted average of the two losses.
        return self._add(
            self._mul(self._w1(), self._loss1(loss_input_name1, target_name="target1")),
            self._mul(self._w2(), self._loss2(loss_input_name2, target_name="target2")),
        )


my_custom_loss = WeightedAverageLoss()

# Load the onnx model
model_path = "model.onnx"
base_model = onnx.load(model_path)

# Define the parameters that need their gradient computed
requires_grad = ["weight1", "bias1", "weight2", "bias2"]
frozen_params = ["weight3", "bias3"]

# Now, we can invoke generate_artifacts with this custom loss function
artifacts.generate_artifacts(
    base_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=my_custom_loss,
    optimizer=artifacts.OptimType.AdamW,
)

# Successful completion of the above call will generate 4 files in the current working directory,
# one for each of the artifacts mentioned above (training_model.onnx, eval_model.onnx, checkpoint, op)
onnx-graphsurgeon implements the main class Graph which provides all the necessary methods to add nodes or import existing onnx files. The following example is taken from onnx-graphsurgeon/examples. The first part generates a graph.
import onnx_graphsurgeon as gs
import numpy as np
import onnx

# Computes Y = x0 + (a * x1 + b)

shape = (1, 3, 224, 224)
# Inputs
x0 = gs.Variable(name="x0", dtype=np.float32, shape=shape)
x1 = gs.Variable(name="x1", dtype=np.float32, shape=shape)

# Intermediate tensors
a = gs.Constant("a", values=np.ones(shape=shape, dtype=np.float32))
b = gs.Constant("b", values=np.ones(shape=shape, dtype=np.float32))
mul_out = gs.Variable(name="mul_out")
add_out = gs.Variable(name="add_out")

# Outputs
Y = gs.Variable(name="Y", dtype=np.float32, shape=shape)

nodes = [
    # mul_out = a * x1
    gs.Node(op="Mul", inputs=[a, x1], outputs=[mul_out]),
    # add_out = mul_out + b
    gs.Node(op="Add", inputs=[mul_out, b], outputs=[add_out]),
    # Y = x0 + add
    gs.Node(op="Add", inputs=[x0, add_out], outputs=[Y]),
]

graph = gs.Graph(nodes=nodes, inputs=[x0, x1], outputs=[Y])
onnx.save(gs.export_onnx(graph), "model.onnx")
The second part modifies it.
import onnx_graphsurgeon as gs
import numpy as np
import onnx

graph = gs.import_onnx(onnx.load("model.onnx"))

# 1. Remove the `b` input of the add node
first_add = [node for node in graph.nodes if node.op == "Add"][0]
first_add.inputs = [inp for inp in first_add.inputs if inp.name != "b"]

# 2. Change the Add to a LeakyRelu
first_add.op = "LeakyRelu"
first_add.attrs["alpha"] = 0.02

# 3. Add an identity after the add node
identity_out = gs.Variable("identity_out", dtype=np.float32)
identity = gs.Node(op="Identity", inputs=first_add.outputs, outputs=[identity_out])
graph.nodes.append(identity)

# 4. Modify the graph output to be the identity output
graph.outputs = [identity_out]

# 5. Remove unused nodes/tensors, and topologically sort the graph
# ONNX requires nodes to be topologically sorted to be considered valid.
# Therefore, you should only need to sort the graph when you have added new nodes out-of-order.
# In this case, the identity node is already in the correct spot (it is the last node,
# and was appended to the end of the list), but to be on the safer side, we can sort anyway.
graph.cleanup().toposort()

onnx.save(gs.export_onnx(graph), "modified.onnx")
See GraphBuilder: common API for ONNX. This API is very similar to what skl2onnx implements. It is still about adding nodes to a graph, but some tasks are automated, such as naming the results or converting constants to onnx classes.
<<<
import numpy as np
from onnx_array_api.graph_api import GraphBuilder
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

g = GraphBuilder()
g.make_tensor_input("X", np.float32, (None, None))
g.make_tensor_input("Y", np.float32, (None, None))
r1 = g.op.Sub("X", "Y")
r2 = g.op.Pow(r1, np.array([2], dtype=np.int64))
g.op.ReduceSum(r2, outputs=["Z"])
g.make_tensor_output("Z", np.float32, (None, None))
onx = g.to_onnx()
print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=22
input: name='X' type=dtype('float32') shape=['', '']
input: name='Y' type=dtype('float32') shape=['', '']
init: name='cst' type=int64 shape=(1,) -- array([2])
Sub(X, Y) -> _onx_sub0
Pow(_onx_sub0, cst) -> _onx_pow0
ReduceSum(_onx_pow0) -> Z
output: name='Z' type=dtype('float32') shape=['', '']
See Light API for ONNX: everything in one line. This API was created to be able to write an onnx graph in one instruction. It is inspired by reverse Polish notation. There is no eager mode.
<<<
import numpy as np
from onnx_array_api.light_api import start
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

model = (
    start()
    .vin("X")
    .vin("Y")
    .bring("X", "Y")
    .Sub()
    .rename("dxy")
    .cst(np.array([2], dtype=np.int64), "two")
    .bring("dxy", "two")
    .Pow()
    .ReduceSum()
    .rename("Z")
    .vout()
    .to_onnx()
)
print(onnx_simple_text_plot(model))
>>>
opset: domain='' version=22
input: name='X' type=dtype('float32') shape=None
input: name='Y' type=dtype('float32') shape=None
init: name='two' type=int64 shape=(1,) -- array([2])
Sub(X, Y) -> dxy
Pow(dxy, two) -> r1_0
ReduceSum(r1_0, keepdims=1, noop_with_empty_axes=0) -> Z
output: name='Z' type=dtype('float32') shape=None
See Numpy API for ONNX. This API was introduced to create graphs by using the numpy API. If a function is defined only with numpy, it should be possible to use the exact same code to create the corresponding onnx graph. That is what this API tries to achieve. It works with the exception of control flow: in that case, the function produces different onnx graphs depending on the execution path.
<<<
import numpy as np
from onnx_array_api.npx import jit_onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot


def l2_loss(x, y):
    return ((x - y) ** 2).sum(keepdims=1)


jitted_myloss = jit_onnx(l2_loss)
dummy = np.array([0], dtype=np.float32)

# The function is executed. Only then is an onnx graph created.
# One is created depending on the input type.
jitted_myloss(dummy, dummy)

# get_onnx only works if it was executed once or at least with
# the same input type
model = jitted_myloss.get_onnx()
print(onnx_simple_text_plot(model))
>>>
opset: domain='' version=18
input: name='x0' type=dtype('float32') shape=['']
input: name='x1' type=dtype('float32') shape=['']
Constant(value=2) -> r__1
Sub(x0, x1) -> r__0
CastLike(r__1, r__0) -> r__2
Pow(r__0, r__2) -> r__3
ReduceSum(r__3, keepdims=1) -> r__4
output: name='r__4' type=dtype('float32') shape=[1]