
Contents
More

onnx-array-api implements APIs to create custom ONNX graphs. The objective is to speed up the implementation of converter libraries.
Sources are available on github/onnx-array-api.
Almost every converting library (converting a machine-learned model to ONNX) implements its own graph builder and customizes it for its needs. GraphBuilder handles some frequent tasks such as giving names to intermediate results, and loading and saving ONNX models. It can also be used to extend an existing graph. See GraphBuilder: common API for ONNX.
<<<
import numpy as np
from onnx_array_api.graph_api import GraphBuilder
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

# Build the graph Z = ReduceSum((X - Y) ** 2) step by step.
g = GraphBuilder()
g.make_tensor_input("X", np.float32, (None, None))
g.make_tensor_input("Y", np.float32, (None, None))

# The name given to the output is chosen by the class,
# which ensures the name is unique.
r1 = g.make_node("Sub", ["X", "Y"])

# The class automatically converts the array to a tensor.
init = g.make_initializer(np.array([2], dtype=np.int64))

r2 = g.make_node("Pow", [r1, init])

# The output name is given explicitly because
# the user wants to choose the name.
g.make_node("ReduceSum", [r2], outputs=["Z"])
g.make_tensor_output("Z", np.float32, (None, None))

onx = g.to_onnx()  # final conversion to onnx

print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=22
input: name='X' type=dtype('float32') shape=['', '']
input: name='Y' type=dtype('float32') shape=['', '']
init: name='cst' type=int64 shape=(1,) -- array([2])
Sub(X, Y) -> _onx_sub0
Pow(_onx_sub0, cst) -> _onx_pow0
ReduceSum(_onx_pow0) -> Z
output: name='Z' type=dtype('float32') shape=['', '']
The syntax is inspired by the Reverse Polish Notation. This kind of API is easy to use to build new graphs, but less easy to use to extend an existing graph. See Light API for ONNX: everything in one line.
<<<
import numpy as np
from onnx_array_api.light_api import start
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

# Each call works on the results left by the previous one
# (Reverse Polish Notation style): bring() selects the operands,
# the operator method consumes them, rename() names the result.
model = (
    start()
    .vin("X")
    .vin("Y")
    .bring("X", "Y")
    .Sub()
    .rename("dxy")
    .cst(np.array([2], dtype=np.int64), "two")
    .bring("dxy", "two")
    .Pow()
    .ReduceSum()
    .rename("Z")
    .vout()
    .to_onnx()
)

print(onnx_simple_text_plot(model))
>>>
opset: domain='' version=22
input: name='X' type=dtype('float32') shape=None
input: name='Y' type=dtype('float32') shape=None
init: name='two' type=int64 shape=(1,) -- array([2])
Sub(X, Y) -> dxy
Pow(dxy, two) -> r1_0
ReduceSum(r1_0, keepdims=1, noop_with_empty_axes=0) -> Z
output: name='Z' type=dtype('float32') shape=None
Writing ONNX graphs requires knowing the ONNX syntax unless it is possible to reuse an existing syntax such as numpy. This is what this API does. This kind of API is easy to use to build new graphs, but almost impossible to use to extend existing graphs, as that usually requires knowing ONNX. See Numpy API for ONNX.
<<<
import numpy as np

from onnx_array_api.npx import absolute, jit_onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot


def l1_loss(x, y):
    """Return the sum of absolute differences between x and y."""
    return absolute(x - y).sum()


def l2_loss(x, y):
    """Return the sum of squared differences between x and y."""
    return ((x - y) ** 2).sum()


def myloss(x, y):
    """Combine an L1 loss on column 0 with a squared L2 loss on column 1."""
    return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])


# jit_onnx wraps myloss so it runs through an ONNX graph;
# the generated graph is retrieved below with get_onnx().
jitted_myloss = jit_onnx(myloss)

x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

res = jitted_myloss(x, y)
print(res)

print(onnx_simple_text_plot(jitted_myloss.get_onnx()))
>>>
0.042
opset: domain='' version=18
input: name='x0' type=dtype('float32') shape=['', '']
input: name='x1' type=dtype('float32') shape=['', '']
Constant(value=[1]) -> cst__0
Constant(value=[2]) -> cst__1
Constant(value=[1]) -> cst__2
Slice(x0, cst__0, cst__1, cst__2) -> r__12
Constant(value=[1]) -> cst__3
Constant(value=[2]) -> cst__4
Constant(value=[1]) -> cst__5
Slice(x1, cst__3, cst__4, cst__5) -> r__14
Constant(value=[0]) -> cst__6
Constant(value=[1]) -> cst__7
Constant(value=[1]) -> cst__8
Slice(x0, cst__6, cst__7, cst__8) -> r__16
Constant(value=[0]) -> cst__9
Constant(value=[1]) -> cst__10
Constant(value=[1]) -> cst__11
Slice(x1, cst__9, cst__10, cst__11) -> r__18
Constant(value=[1]) -> cst__13
Squeeze(r__12, cst__13) -> r__20
Constant(value=[1]) -> cst__15
Squeeze(r__14, cst__15) -> r__21
Sub(r__20, r__21) -> r__24
Constant(value=[1]) -> cst__17
Squeeze(r__16, cst__17) -> r__22
Constant(value=[1]) -> cst__19
Squeeze(r__18, cst__19) -> r__23
Sub(r__22, r__23) -> r__25
Abs(r__25) -> r__28
ReduceSum(r__28, keepdims=0) -> r__30
Constant(value=2) -> r__26
CastLike(r__26, r__24) -> r__27
Pow(r__24, r__27) -> r__29
ReduceSum(r__29, keepdims=0) -> r__31
Add(r__30, r__31) -> r__32
output: name='r__32' type=dtype('float32') shape=None
![digraph{ orientation=portrait; size=7; nodesep=0.05; ranksep=0.25; x0 [shape=box color=red label="x0\nTensorProto.FLOAT\nshape=['', '']" fontsize=10]; x1 [shape=box color=red label="x1\nTensorProto.FLOAT\nshape=['', '']" fontsize=10]; r__32 [shape=box color=green label="r__32\nTensorProto.FLOAT" fontsize=10]; cst__0 [shape=box label="cst__0" fontsize=10]; Constant [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant -> cst__0; cst__1 [shape=box label="cst__1" fontsize=10]; Constant1 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[2]" fontsize=10]; Constant1 -> cst__1; cst__2 [shape=box label="cst__2" fontsize=10]; Constant12 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant12 -> cst__2; cst__3 [shape=box label="cst__3" fontsize=10]; Constant123 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant123 -> cst__3; cst__4 [shape=box label="cst__4" fontsize=10]; Constant1234 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[2]" fontsize=10]; Constant1234 -> cst__4; cst__5 [shape=box label="cst__5" fontsize=10]; Constant12345 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant12345 -> cst__5; cst__6 [shape=box label="cst__6" fontsize=10]; Constant123456 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[0]" fontsize=10]; Constant123456 -> cst__6; cst__7 [shape=box label="cst__7" fontsize=10]; Constant1234567 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant1234567 -> cst__7; cst__8 [shape=box label="cst__8" fontsize=10]; Constant12345678 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant12345678 -> cst__8; cst__9 [shape=box label="cst__9" fontsize=10]; Constant123456789 [shape=box style="filled,rounded" color=orange 
label="Constant\nvalue=[0]" fontsize=10]; Constant123456789 -> cst__9; cst__10 [shape=box label="cst__10" fontsize=10]; Constant12345678910 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant12345678910 -> cst__10; cst__11 [shape=box label="cst__11" fontsize=10]; Constant1234567891011 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant1234567891011 -> cst__11; r__12 [shape=box label="r__12" fontsize=10]; Slice [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10]; x0 -> Slice; cst__0 -> Slice; cst__1 -> Slice; cst__2 -> Slice; Slice -> r__12; cst__13 [shape=box label="cst__13" fontsize=10]; Constant123456789101112 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant123456789101112 -> cst__13; r__14 [shape=box label="r__14" fontsize=10]; Slice1 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10]; x1 -> Slice1; cst__3 -> Slice1; cst__4 -> Slice1; cst__5 -> Slice1; Slice1 -> r__14; cst__15 [shape=box label="cst__15" fontsize=10]; Constant12345678910111213 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant12345678910111213 -> cst__15; r__16 [shape=box label="r__16" fontsize=10]; Slice12 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10]; x0 -> Slice12; cst__6 -> Slice12; cst__7 -> Slice12; cst__8 -> Slice12; Slice12 -> r__16; cst__17 [shape=box label="cst__17" fontsize=10]; Constant1234567891011121314 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant1234567891011121314 -> cst__17; r__18 [shape=box label="r__18" fontsize=10]; Slice123 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10]; x1 -> Slice123; cst__9 -> Slice123; cst__10 -> Slice123; cst__11 -> Slice123; Slice123 -> r__18; cst__19 [shape=box label="cst__19" fontsize=10]; 
Constant123456789101112131415 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant123456789101112131415 -> cst__19; r__20 [shape=box label="r__20" fontsize=10]; Squeeze [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10]; r__12 -> Squeeze; cst__13 -> Squeeze; Squeeze -> r__20; r__21 [shape=box label="r__21" fontsize=10]; Squeeze1 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10]; r__14 -> Squeeze1; cst__15 -> Squeeze1; Squeeze1 -> r__21; r__22 [shape=box label="r__22" fontsize=10]; Squeeze12 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10]; r__16 -> Squeeze12; cst__17 -> Squeeze12; Squeeze12 -> r__22; r__23 [shape=box label="r__23" fontsize=10]; Squeeze123 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10]; r__18 -> Squeeze123; cst__19 -> Squeeze123; Squeeze123 -> r__23; r__24 [shape=box label="r__24" fontsize=10]; Sub [shape=box style="filled,rounded" color=orange label="Sub" fontsize=10]; r__20 -> Sub; r__21 -> Sub; Sub -> r__24; r__25 [shape=box label="r__25" fontsize=10]; Sub1 [shape=box style="filled,rounded" color=orange label="Sub" fontsize=10]; r__22 -> Sub1; r__23 -> Sub1; Sub1 -> r__25; r__26 [shape=box label="r__26" fontsize=10]; Constant12345678910111213141516 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=2" fontsize=10]; Constant12345678910111213141516 -> r__26; r__27 [shape=box label="r__27" fontsize=10]; CastLike [shape=box style="filled,rounded" color=orange label="CastLike" fontsize=10]; r__26 -> CastLike; r__24 -> CastLike; CastLike -> r__27; r__28 [shape=box label="r__28" fontsize=10]; Abs [shape=box style="filled,rounded" color=orange label="Abs" fontsize=10]; r__25 -> Abs; Abs -> r__28; r__29 [shape=box label="r__29" fontsize=10]; Pow [shape=box style="filled,rounded" color=orange label="Pow" fontsize=10]; r__24 -> Pow; r__27 -> Pow; Pow -> r__29; r__30 [shape=box 
label="r__30" fontsize=10]; ReduceSum [shape=box style="filled,rounded" color=orange label="ReduceSum\nkeepdims=0" fontsize=10]; r__28 -> ReduceSum; ReduceSum -> r__30; r__31 [shape=box label="r__31" fontsize=10]; ReduceSum1 [shape=box style="filled,rounded" color=orange label="ReduceSum\nkeepdims=0" fontsize=10]; r__29 -> ReduceSum1; ReduceSum1 -> r__31; Add [shape=box style="filled,rounded" color=orange label="Add" fontsize=10]; r__30 -> Add; r__31 -> Add; Add -> r__32;}](/image.pl?url=https%3a%2f%2fsdpython.github.io%2fdoc%2fonnx-array-api%2fdev%2f_images%2fgraphviz-3090b200d8f5147e9b30460213b1854077156835.png&f=jpg&w=240)