onnx-array-api 0.3.1 documentation

onnx-array-api: APIs to create ONNX Graphs


onnx-array-api implements APIs to create custom ONNX graphs. The objective is to speed up the implementation of converter libraries.

Sources are available on github/onnx-array-api.

GraphBuilder API

Almost every converter library (a library converting a machine-learned model to ONNX) implements its own graph builder and customizes it for its needs. A graph builder handles frequent tasks such as naming intermediate results and loading or saving ONNX models. It can also be used to extend an existing graph. See GraphBuilder: common API for ONNX.

<<<

import numpy as np
from onnx_array_api.graph_api import GraphBuilder
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

g = GraphBuilder()
g.make_tensor_input("X", np.float32, (None, None))
g.make_tensor_input("Y", np.float32, (None, None))
# the class gives the output a name and ensures the name is unique
r1 = g.make_node("Sub", ["X", "Y"])
# the class automatically converts the array to a tensor
init = g.make_initializer(np.array([2], dtype=np.int64))
r2 = g.make_node("Pow", [r1, init])
# the output name is given explicitly because the user wants to choose it
g.make_node("ReduceSum", [r2], outputs=["Z"])
g.make_tensor_output("Z", np.float32, (None, None))

onx = g.to_onnx()  # final conversion to onnx

print(onnx_simple_text_plot(onx))

>>>

opset: domain='' version=22
input: name='X' type=dtype('float32') shape=['', '']
input: name='Y' type=dtype('float32') shape=['', '']
init: name='cst' type=int64 shape=(1,) -- array([2])
Sub(X, Y) -> _onx_sub0
Pow(_onx_sub0, cst) -> _onx_pow0
ReduceSum(_onx_pow0) -> Z
output: name='Z' type=dtype('float32') shape=['', '']
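The builder names intermediate results itself (`_onx_sub0`, `_onx_pow0` above) and guarantees they are unique. The following is a minimal sketch of how such a naming scheme can work, assuming a simple "prefix + operator type + first free index" rule. `NameGenerator` is a hypothetical helper written for illustration; it is not part of onnx-array-api, whose actual logic may differ.

```python
class NameGenerator:
    """Hypothetical helper (not onnx-array-api): hands out unique result names."""

    def __init__(self, prefix="_onx_"):
        self.prefix = prefix
        self.used = set()

    def reserve(self, name):
        # Names chosen by the user (inputs, outputs) must never be handed out again.
        if name in self.used:
            raise ValueError(f"name {name!r} is already used")
        self.used.add(name)
        return name

    def unique(self, op_type):
        # Try <prefix><op>0, <prefix><op>1, ... until an unused name is found.
        i = 0
        while True:
            candidate = f"{self.prefix}{op_type.lower()}{i}"
            if candidate not in self.used:
                self.used.add(candidate)
                return candidate
            i += 1


names = NameGenerator()
names.reserve("X")
names.reserve("Y")
r1 = names.unique("Sub")  # first Sub result
r2 = names.unique("Pow")  # first Pow result
r3 = names.unique("Sub")  # a second Sub gets the next free index
print(r1, r2, r3)
```

Keeping a single set of used names, rather than one counter per operator, also protects names the user reserved explicitly, such as `Z` in the example above.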

Light API

The syntax is inspired by Reverse Polish Notation. This kind of API is easy to use to build new graphs but less convenient for extending an existing graph. See Light API for ONNX: everything in one line.

<<<

import numpy as np
from onnx_array_api.light_api import start
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

model = (
    start()
    .vin("X")
    .vin("Y")
    .bring("X", "Y")
    .Sub()
    .rename("dxy")
    .cst(np.array([2], dtype=np.int64), "two")
    .bring("dxy", "two")
    .Pow()
    .ReduceSum()
    .rename("Z")
    .vout()
    .to_onnx()
)

print(onnx_simple_text_plot(model))

>>>

opset: domain='' version=22
input: name='X' type=dtype('float32') shape=None
input: name='Y' type=dtype('float32') shape=None
init: name='two' type=int64 shape=(1,) -- array([2])
Sub(X, Y) -> dxy
Pow(dxy, two) -> r1_0
ReduceSum(r1_0, keepdims=1, noop_with_empty_axes=0) -> Z
output: name='Z' type=dtype('float32') shape=None
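In Reverse Polish Notation, operands are placed first and each operator consumes the most recent results. The sketch below imitates the chain above with a plain Python stack. `RPNGraph` and its methods are hypothetical and only mimic the spirit of `start().vin(...).bring(...).Sub()`: the real API exposes operators as methods such as `.Sub()` and `.Pow()`, whereas this sketch uses a generic `op(...)` and records `(op_type, inputs, output)` tuples instead of building an ONNX model.

```python
class RPNGraph:
    """Hypothetical stack-based builder illustrating the Reverse Polish style."""

    def __init__(self):
        self.inputs, self.outputs, self.nodes = [], [], []
        self.constants = {}
        self.stack = []  # results the next operator will consume
        self._counter = 0

    def vin(self, name):
        # Declare a graph input; it becomes the current result.
        self.inputs.append(name)
        self.stack = [name]
        return self

    def cst(self, value, name):
        # Register a named constant; it becomes the current result.
        self.constants[name] = value
        self.stack = [name]
        return self

    def bring(self, *names):
        # Choose explicitly which results the next operator consumes.
        self.stack = list(names)
        return self

    def op(self, op_type):
        # The operator consumes the whole stack and leaves one new result on it.
        output = f"r{self._counter}"
        self._counter += 1
        self.nodes.append((op_type, tuple(self.stack), output))
        self.stack = [output]
        return self

    def rename(self, new_name):
        # Rename the output of the last node (not yet used by any later node).
        op_type, inputs, _ = self.nodes[-1]
        self.nodes[-1] = (op_type, inputs, new_name)
        self.stack = [new_name]
        return self

    def vout(self):
        # Mark the current result(s) as graph outputs.
        self.outputs.extend(self.stack)
        return self


rpn = (
    RPNGraph()
    .vin("X").vin("Y")
    .bring("X", "Y").op("Sub").rename("dxy")
    .cst([2], "two")
    .bring("dxy", "two").op("Pow")
    .op("ReduceSum").rename("Z")
    .vout()
)
for node in rpn.nodes:
    print(node)
```

Because every method returns the builder, the whole graph fits in one expression, which is what makes this style convenient for new graphs and awkward for editing an existing one.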

Numpy API

Writing ONNX graphs requires knowing the ONNX syntax, unless an existing syntax such as numpy can be reused. This is what this API does. It is easy to use to build new graphs but almost impossible to use to extend existing graphs, since that usually requires knowing ONNX. See Numpy API for ONNX.
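One common way to reuse numpy syntax is operator overloading: a placeholder object records every operation applied to it instead of computing a value. The minimal sketch below assumes nothing about the internals of `onnx_array_api.npx`; `Tracer` and `Traced` are hypothetical classes that record `(op_type, inputs, output)` tuples named after ONNX operators.

```python
import itertools


class Tracer:
    """Hypothetical recorder: collects the nodes produced by traced operations."""

    def __init__(self):
        self.nodes = []
        self._ids = itertools.count()

    def add(self, op_type, *inputs):
        # Each recorded operation produces a new placeholder with a fresh name.
        out = Traced(self, f"r{next(self._ids)}")
        names = tuple(i.name if isinstance(i, Traced) else repr(i) for i in inputs)
        self.nodes.append((op_type, names, out.name))
        return out


class Traced:
    """Hypothetical stand-in for an array: operators append ONNX-like nodes."""

    def __init__(self, tracer, name):
        self.tracer, self.name = tracer, name

    def __sub__(self, other):
        return self.tracer.add("Sub", self, other)

    def __add__(self, other):
        return self.tracer.add("Add", self, other)

    def __pow__(self, exponent):
        return self.tracer.add("Pow", self, exponent)

    def sum(self):
        return self.tracer.add("ReduceSum", self)


def l2_loss(x, y):
    # Ordinary numpy-style code: it runs on numpy arrays or on placeholders.
    return ((x - y) ** 2).sum()


tracer = Tracer()
result = l2_loss(Traced(tracer, "x"), Traced(tracer, "y"))
for node in tracer.nodes:
    print(node)
```

The same `l2_loss` still works unchanged on numpy arrays, which is the point of reusing numpy syntax; extending an existing graph, however, is outside what such tracing can express.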

<<<

import numpy as np
from onnx_array_api.npx import absolute, jit_onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot


def l1_loss(x, y):
    return absolute(x - y).sum()


def l2_loss(x, y):
    return ((x - y) ** 2).sum()


def myloss(x, y):
    return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])


jitted_myloss = jit_onnx(myloss)

x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

res = jitted_myloss(x, y)
print(res)

print(onnx_simple_text_plot(jitted_myloss.get_onnx()))

>>>

0.042
opset: domain='' version=18
input: name='x0' type=dtype('float32') shape=['', '']
input: name='x1' type=dtype('float32') shape=['', '']
Constant(value=[1]) -> cst__0
Constant(value=[2]) -> cst__1
Constant(value=[1]) -> cst__2
Slice(x0, cst__0, cst__1, cst__2) -> r__12
Constant(value=[1]) -> cst__3
Constant(value=[2]) -> cst__4
Constant(value=[1]) -> cst__5
Slice(x1, cst__3, cst__4, cst__5) -> r__14
Constant(value=[0]) -> cst__6
Constant(value=[1]) -> cst__7
Constant(value=[1]) -> cst__8
Slice(x0, cst__6, cst__7, cst__8) -> r__16
Constant(value=[0]) -> cst__9
Constant(value=[1]) -> cst__10
Constant(value=[1]) -> cst__11
Slice(x1, cst__9, cst__10, cst__11) -> r__18
Constant(value=[1]) -> cst__13
Squeeze(r__12, cst__13) -> r__20
Constant(value=[1]) -> cst__15
Squeeze(r__14, cst__15) -> r__21
Sub(r__20, r__21) -> r__24
Constant(value=[1]) -> cst__17
Squeeze(r__16, cst__17) -> r__22
Constant(value=[1]) -> cst__19
Squeeze(r__18, cst__19) -> r__23
Sub(r__22, r__23) -> r__25
Abs(r__25) -> r__28
ReduceSum(r__28, keepdims=0) -> r__30
Constant(value=2) -> r__26
CastLike(r__26, r__24) -> r__27
Pow(r__24, r__27) -> r__29
ReduceSum(r__29, keepdims=0) -> r__31
Add(r__30, r__31) -> r__32
output: name='r__32' type=dtype('float32') shape=None
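To check the printed value by hand, the same loss can be computed eagerly with numpy alone. This reference function is written for illustration only: it uses `np.abs` and the array `sum` method in place of the `onnx_array_api.npx` functions and does not involve ONNX.

```python
import numpy as np


def myloss_reference(x, y):
    # L1 loss on the first column: |0.1 - 0.11| + |0.3 - 0.33| = 0.01 + 0.03 = 0.04
    l1 = np.abs(x[:, 0] - y[:, 0]).sum()
    # L2 loss on the second column: (0.2 - 0.22)**2 + (0.4 - 0.44)**2 = 0.0004 + 0.0016 = 0.002
    l2 = ((x[:, 1] - y[:, 1]) ** 2).sum()
    return l1 + l2


x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
res = myloss_reference(x, y)
print(res)
```

The total is 0.04 + 0.002 = 0.042, matching the value returned by `jitted_myloss`; the last digits printed may differ slightly because the computation runs in float32.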

Older versions
