

onnx-array-api implements APIs to create custom ONNX graphs. The objective is to speed up the implementation of converter libraries.
Sources are available on github/onnx-array-api.
<<<
import numpy as np

# numpy-like functions that can be traced into an ONNX graph
from onnx_array_api.npx import absolute, jit_onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot


def l1_loss(x, y):
    """Return the sum of absolute differences between ``x`` and ``y``."""
    return absolute(x - y).sum()


def l2_loss(x, y):
    """Return the sum of squared differences between ``x`` and ``y``."""
    return ((x - y) ** 2).sum()


def myloss(x, y):
    """Combine an L1 loss on column 0 with an L2 loss on column 1.

    ``x`` and ``y`` are indexed as 2-D arrays (``[:, 0]`` / ``[:, 1]``),
    so both must have at least two columns.
    """
    return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])


# Wrap ``myloss`` so that calling it goes through an ONNX graph built from
# the numpy-like operations above instead of running them eagerly.
jitted_myloss = jit_onnx(myloss)

x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

# |0.1-0.11| + |0.3-0.33| + (0.2-0.22)**2 + (0.4-0.44)**2 == 0.042
res = jitted_myloss(x, y)
print(res)

# Print a text rendering of the ONNX graph produced for the call above.
print(onnx_simple_text_plot(jitted_myloss.get_onnx()))
>>>
0.042
opset: domain='' version=18
input: name='x0' type=dtype('float32') shape=['', '']
input: name='x1' type=dtype('float32') shape=['', '']
Constant(value=[1]) -> cst__0
Constant(value=[2]) -> cst__1
Constant(value=[1]) -> cst__2
Slice(x0, cst__0, cst__1, cst__2) -> r__12
Constant(value=[1]) -> cst__3
Constant(value=[2]) -> cst__4
Constant(value=[1]) -> cst__5
Slice(x1, cst__3, cst__4, cst__5) -> r__14
Constant(value=[0]) -> cst__6
Constant(value=[1]) -> cst__7
Constant(value=[1]) -> cst__8
Slice(x0, cst__6, cst__7, cst__8) -> r__16
Constant(value=[0]) -> cst__9
Constant(value=[1]) -> cst__10
Constant(value=[1]) -> cst__11
Slice(x1, cst__9, cst__10, cst__11) -> r__18
Constant(value=[1]) -> cst__13
Squeeze(r__12, cst__13) -> r__20
Constant(value=[1]) -> cst__15
Squeeze(r__14, cst__15) -> r__21
Sub(r__20, r__21) -> r__24
Constant(value=[1]) -> cst__17
Squeeze(r__16, cst__17) -> r__22
Constant(value=[1]) -> cst__19
Squeeze(r__18, cst__19) -> r__23
Sub(r__22, r__23) -> r__25
Abs(r__25) -> r__28
ReduceSum(r__28, keepdims=0) -> r__30
Constant(value=2) -> r__26
CastLike(r__26, r__24) -> r__27
Pow(r__24, r__27) -> r__29
ReduceSum(r__29, keepdims=0) -> r__31
Add(r__30, r__31) -> r__32
output: name='r__32' type=dtype('float32') shape=None
![digraph{ orientation=portrait; ranksep=0.25; nodesep=0.05; size=7; x0 [shape=box color=red label="x0\nTensorProto.FLOAT\nshape=['', '']" fontsize=10]; x1 [shape=box color=red label="x1\nTensorProto.FLOAT\nshape=['', '']" fontsize=10]; r__32 [shape=box color=green label="r__32\nTensorProto.FLOAT" fontsize=10]; cst__0 [shape=box label="cst__0" fontsize=10]; Constant [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant -> cst__0; cst__1 [shape=box label="cst__1" fontsize=10]; Constant1 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[2]" fontsize=10]; Constant1 -> cst__1; cst__2 [shape=box label="cst__2" fontsize=10]; Constant12 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant12 -> cst__2; cst__3 [shape=box label="cst__3" fontsize=10]; Constant123 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant123 -> cst__3; cst__4 [shape=box label="cst__4" fontsize=10]; Constant1234 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[2]" fontsize=10]; Constant1234 -> cst__4; cst__5 [shape=box label="cst__5" fontsize=10]; Constant12345 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant12345 -> cst__5; cst__6 [shape=box label="cst__6" fontsize=10]; Constant123456 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[0]" fontsize=10]; Constant123456 -> cst__6; cst__7 [shape=box label="cst__7" fontsize=10]; Constant1234567 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant1234567 -> cst__7; cst__8 [shape=box label="cst__8" fontsize=10]; Constant12345678 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant12345678 -> cst__8; cst__9 [shape=box label="cst__9" fontsize=10]; Constant123456789 [shape=box style="filled,rounded" color=orange 
label="Constant\nvalue=[0]" fontsize=10]; Constant123456789 -> cst__9; cst__10 [shape=box label="cst__10" fontsize=10]; Constant12345678910 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant12345678910 -> cst__10; cst__11 [shape=box label="cst__11" fontsize=10]; Constant1234567891011 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant1234567891011 -> cst__11; r__12 [shape=box label="r__12" fontsize=10]; Slice [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10]; x0 -> Slice; cst__0 -> Slice; cst__1 -> Slice; cst__2 -> Slice; Slice -> r__12; cst__13 [shape=box label="cst__13" fontsize=10]; Constant123456789101112 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant123456789101112 -> cst__13; r__14 [shape=box label="r__14" fontsize=10]; Slice1 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10]; x1 -> Slice1; cst__3 -> Slice1; cst__4 -> Slice1; cst__5 -> Slice1; Slice1 -> r__14; cst__15 [shape=box label="cst__15" fontsize=10]; Constant12345678910111213 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant12345678910111213 -> cst__15; r__16 [shape=box label="r__16" fontsize=10]; Slice12 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10]; x0 -> Slice12; cst__6 -> Slice12; cst__7 -> Slice12; cst__8 -> Slice12; Slice12 -> r__16; cst__17 [shape=box label="cst__17" fontsize=10]; Constant1234567891011121314 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant1234567891011121314 -> cst__17; r__18 [shape=box label="r__18" fontsize=10]; Slice123 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10]; x1 -> Slice123; cst__9 -> Slice123; cst__10 -> Slice123; cst__11 -> Slice123; Slice123 -> r__18; cst__19 [shape=box label="cst__19" fontsize=10]; 
Constant123456789101112131415 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10]; Constant123456789101112131415 -> cst__19; r__20 [shape=box label="r__20" fontsize=10]; Squeeze [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10]; r__12 -> Squeeze; cst__13 -> Squeeze; Squeeze -> r__20; r__21 [shape=box label="r__21" fontsize=10]; Squeeze1 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10]; r__14 -> Squeeze1; cst__15 -> Squeeze1; Squeeze1 -> r__21; r__22 [shape=box label="r__22" fontsize=10]; Squeeze12 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10]; r__16 -> Squeeze12; cst__17 -> Squeeze12; Squeeze12 -> r__22; r__23 [shape=box label="r__23" fontsize=10]; Squeeze123 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10]; r__18 -> Squeeze123; cst__19 -> Squeeze123; Squeeze123 -> r__23; r__24 [shape=box label="r__24" fontsize=10]; Sub [shape=box style="filled,rounded" color=orange label="Sub" fontsize=10]; r__20 -> Sub; r__21 -> Sub; Sub -> r__24; r__25 [shape=box label="r__25" fontsize=10]; Sub1 [shape=box style="filled,rounded" color=orange label="Sub" fontsize=10]; r__22 -> Sub1; r__23 -> Sub1; Sub1 -> r__25; r__26 [shape=box label="r__26" fontsize=10]; Constant12345678910111213141516 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=2" fontsize=10]; Constant12345678910111213141516 -> r__26; r__27 [shape=box label="r__27" fontsize=10]; CastLike [shape=box style="filled,rounded" color=orange label="CastLike" fontsize=10]; r__26 -> CastLike; r__24 -> CastLike; CastLike -> r__27; r__28 [shape=box label="r__28" fontsize=10]; Abs [shape=box style="filled,rounded" color=orange label="Abs" fontsize=10]; r__25 -> Abs; Abs -> r__28; r__29 [shape=box label="r__29" fontsize=10]; Pow [shape=box style="filled,rounded" color=orange label="Pow" fontsize=10]; r__24 -> Pow; r__27 -> Pow; Pow -> r__29; r__30 [shape=box 
label="r__30" fontsize=10]; ReduceSum [shape=box style="filled,rounded" color=orange label="ReduceSum\nkeepdims=0" fontsize=10]; r__28 -> ReduceSum; ReduceSum -> r__30; r__31 [shape=box label="r__31" fontsize=10]; ReduceSum1 [shape=box style="filled,rounded" color=orange label="ReduceSum\nkeepdims=0" fontsize=10]; r__29 -> ReduceSum1; ReduceSum1 -> r__31; Add [shape=box style="filled,rounded" color=orange label="Add" fontsize=10]; r__30 -> Add; r__31 -> Add; Add -> r__32;}](/image.pl?url=https%3a%2f%2fsdpython.github.io%2fdoc%2fonnx-array-api%2fdev%2ftutorial%2f..%2ftutorial%2f..%2f..%2fv0.1.3%2f_images%2fgraphviz-8aae71c054187835c6bebb725974c370d7109aab.png&f=jpg&w=240)
<<<
import numpy as np
from onnx_array_api.light_api import start
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

# Build Z = ReduceSum((X - Y) ** 2) with the fluent light API.
model = (
    start()
    # declare the two graph inputs
    .vin("X")
    .vin("Y")
    # use X and Y as the operands of the next operator
    .bring("X", "Y")
    .Sub()
    .rename("dxy")
    # register the exponent as a named int64 constant
    .cst(np.array([2], dtype=np.int64), "two")
    .bring("dxy", "two")
    .Pow()
    .ReduceSum()
    .rename("Z")
    # mark the last result as the graph output
    .vout()
    .to_onnx()
)

print(onnx_simple_text_plot(model))
>>>
opset: domain='' version=20
input: name='X' type=dtype('float32') shape=None
input: name='Y' type=dtype('float32') shape=None
init: name='two' type=dtype('int64') shape=(1,) -- array([2])
Sub(X, Y) -> dxy
Pow(dxy, two) -> r1_0
ReduceSum(r1_0, keepdims=1, noop_with_empty_axes=0) -> Z
output: name='Z' type=dtype('float32') shape=None