Movatterモバイル変換


[0]ホーム

URL:


ContentsMenuExpandLight modeDark modeAuto light/dark, in light modeAuto light/dark, in dark modeSkip to content
onnx-array-api 0.3.1 documentation
Logo
onnx-array-api 0.3.1 documentation

Contents

More

Back to top

plotting

Dot

onnx_array_api.plotting.dot_plot.to_dot(proto: ModelProto, recursive: bool = False, prefix: str = '', use_onnx: bool = False, add_functions: bool = True, rt_shapes: Dict[str, Tuple[int, ...]] | None = None, **params) → str [source]

Produces a DOT language string for the graph.

Parameters:
  • params – additional params to draw the graph

  • recursive – also shows subgraphs inside operators like Scan

  • prefix – prefix for every node name

  • use_onnx – use the onnx dot format instead of this one

  • add_functions – add functions to the graph

  • rt_shapes – indicates shapes obtained from the execution or inference

Returns:

string

Default options for the graph are:

options={'orientation':'portrait','ranksep':'0.25','nodesep':'0.05','width':'0.5','height':'0.1','size':'7',}

One example:

<<<

importnumpyasnp# Bfromonnx_array_api.npximportabsolute,jit_onnxfromonnx_array_api.plotting.dot_plotimportto_dotdefl1_loss(x,y):returnabsolute(x-y).sum()defl2_loss(x,y):return((x-y)**2).sum()defmyloss(x,y):returnl1_loss(x[:,0],y[:,0])+l2_loss(x[:,1],y[:,1])jitted_myloss=jit_onnx(myloss)x=np.array([[0.1,0.2],[0.3,0.4]],dtype=np.float32)y=np.array([[0.11,0.22],[0.33,0.44]],dtype=np.float32)res=jitted_myloss(x,y)print(res)

>>>

0.042
digraph{  ranksep=0.25;  nodesep=0.05;  size=7;  orientation=portrait;  x0 [shape=box color=red label="x0\nTensorProto.FLOAT\nshape=['', '']" fontsize=10];  x1 [shape=box color=red label="x1\nTensorProto.FLOAT\nshape=['', '']" fontsize=10];  r__32 [shape=box color=green label="r__32\nTensorProto.FLOAT" fontsize=10];  cst__0 [shape=box label="cst__0" fontsize=10];  Constant [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];  Constant -> cst__0;  cst__1 [shape=box label="cst__1" fontsize=10];  Constant1 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[2]" fontsize=10];  Constant1 -> cst__1;  cst__2 [shape=box label="cst__2" fontsize=10];  Constant12 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];  Constant12 -> cst__2;  cst__3 [shape=box label="cst__3" fontsize=10];  Constant123 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];  Constant123 -> cst__3;  cst__4 [shape=box label="cst__4" fontsize=10];  Constant1234 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[2]" fontsize=10];  Constant1234 -> cst__4;  cst__5 [shape=box label="cst__5" fontsize=10];  Constant12345 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];  Constant12345 -> cst__5;  cst__6 [shape=box label="cst__6" fontsize=10];  Constant123456 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[0]" fontsize=10];  Constant123456 -> cst__6;  cst__7 [shape=box label="cst__7" fontsize=10];  Constant1234567 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];  Constant1234567 -> cst__7;  cst__8 [shape=box label="cst__8" fontsize=10];  Constant12345678 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];  Constant12345678 -> cst__8;  cst__9 [shape=box label="cst__9" fontsize=10];  Constant123456789 [shape=box 
style="filled,rounded" color=orange label="Constant\nvalue=[0]" fontsize=10];  Constant123456789 -> cst__9;  cst__10 [shape=box label="cst__10" fontsize=10];  Constant12345678910 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];  Constant12345678910 -> cst__10;  cst__11 [shape=box label="cst__11" fontsize=10];  Constant1234567891011 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];  Constant1234567891011 -> cst__11;  r__12 [shape=box label="r__12" fontsize=10];  Slice [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];  x0 -> Slice;  cst__0 -> Slice;  cst__1 -> Slice;  cst__2 -> Slice;  Slice -> r__12;  cst__13 [shape=box label="cst__13" fontsize=10];  Constant123456789101112 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];  Constant123456789101112 -> cst__13;  r__14 [shape=box label="r__14" fontsize=10];  Slice1 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];  x1 -> Slice1;  cst__3 -> Slice1;  cst__4 -> Slice1;  cst__5 -> Slice1;  Slice1 -> r__14;  cst__15 [shape=box label="cst__15" fontsize=10];  Constant12345678910111213 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];  Constant12345678910111213 -> cst__15;  r__16 [shape=box label="r__16" fontsize=10];  Slice12 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];  x0 -> Slice12;  cst__6 -> Slice12;  cst__7 -> Slice12;  cst__8 -> Slice12;  Slice12 -> r__16;  cst__17 [shape=box label="cst__17" fontsize=10];  Constant1234567891011121314 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];  Constant1234567891011121314 -> cst__17;  r__18 [shape=box label="r__18" fontsize=10];  Slice123 [shape=box style="filled,rounded" color=orange label="Slice" fontsize=10];  x1 -> Slice123;  cst__9 -> Slice123;  cst__10 -> Slice123;  cst__11 -> Slice123;  Slice123 -> 
r__18;  cst__19 [shape=box label="cst__19" fontsize=10];  Constant123456789101112131415 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=[1]" fontsize=10];  Constant123456789101112131415 -> cst__19;  r__20 [shape=box label="r__20" fontsize=10];  Squeeze [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];  r__12 -> Squeeze;  cst__13 -> Squeeze;  Squeeze -> r__20;  r__21 [shape=box label="r__21" fontsize=10];  Squeeze1 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];  r__14 -> Squeeze1;  cst__15 -> Squeeze1;  Squeeze1 -> r__21;  r__22 [shape=box label="r__22" fontsize=10];  Squeeze12 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];  r__16 -> Squeeze12;  cst__17 -> Squeeze12;  Squeeze12 -> r__22;  r__23 [shape=box label="r__23" fontsize=10];  Squeeze123 [shape=box style="filled,rounded" color=orange label="Squeeze" fontsize=10];  r__18 -> Squeeze123;  cst__19 -> Squeeze123;  Squeeze123 -> r__23;  r__24 [shape=box label="r__24" fontsize=10];  Sub [shape=box style="filled,rounded" color=orange label="Sub" fontsize=10];  r__20 -> Sub;  r__21 -> Sub;  Sub -> r__24;  r__25 [shape=box label="r__25" fontsize=10];  Sub1 [shape=box style="filled,rounded" color=orange label="Sub" fontsize=10];  r__22 -> Sub1;  r__23 -> Sub1;  Sub1 -> r__25;  r__26 [shape=box label="r__26" fontsize=10];  Constant12345678910111213141516 [shape=box style="filled,rounded" color=orange label="Constant\nvalue=2" fontsize=10];  Constant12345678910111213141516 -> r__26;  r__27 [shape=box label="r__27" fontsize=10];  CastLike [shape=box style="filled,rounded" color=orange label="CastLike" fontsize=10];  r__26 -> CastLike;  r__24 -> CastLike;  CastLike -> r__27;  r__28 [shape=box label="r__28" fontsize=10];  Abs [shape=box style="filled,rounded" color=orange label="Abs" fontsize=10];  r__25 -> Abs;  Abs -> r__28;  r__29 [shape=box label="r__29" fontsize=10];  Pow [shape=box style="filled,rounded" 
color=orange label="Pow" fontsize=10];  r__24 -> Pow;  r__27 -> Pow;  Pow -> r__29;  r__30 [shape=box label="r__30" fontsize=10];  ReduceSum [shape=box style="filled,rounded" color=orange label="ReduceSum\nkeepdims=0" fontsize=10];  r__28 -> ReduceSum;  ReduceSum -> r__30;  r__31 [shape=box label="r__31" fontsize=10];  ReduceSum1 [shape=box style="filled,rounded" color=orange label="ReduceSum\nkeepdims=0" fontsize=10];  r__29 -> ReduceSum1;  ReduceSum1 -> r__31;  Add [shape=box style="filled,rounded" color=orange label="Add" fontsize=10];  r__30 -> Add;  r__31 -> Add;  Add -> r__32;}
onnx_array_api.plotting.graphviz_helper.plot_dot(dot: str | ModelProto, ax: matplotlib.axis.Axis | None = None, engine: str = 'dot', figsize: Tuple[int, int] | None = None) → matplotlib.axis.Axis [source]

Draws a dot graph into a matplotlib graph.

Parameters:
  • dot – dot graph or ModelProto

  • ax – matplotlib axis to draw into, or None

  • engine – dot or neato

  • figsize – figure size, used if ax is None

Returns:

the matplotlib axis (matplotlib.axis.Axis) the graph is drawn into

(Source code, png, hires.png, pdf)

../_images/plotting-1.png

Statistics

onnx_array_api.plotting.stat_plot.plot_ort_profile(df: DataFrame, ax0: Any | None = None, ax1: Any | None = None, title: str | None = None) → Any [source]

Plots the time spent in computation, based on a dataframe produced by function ort_profile.

Parameters:
  • df – dataframe

  • ax0 – first axis to draw time

  • ax1 – second axis to draw occurrences

  • title – graph title

Returns:

ax0

See Profiling for an example.

Text

onnx_array_api.plotting.text_plot.onnx_text_plot_tree(node)[source]

Gives a textual representation of a tree ensemble.

Parameters:

node – TreeEnsemble* node

Returns:

text

<<<

importnumpyfromsklearn.datasetsimportload_irisfromsklearn.treeimportDecisionTreeRegressorfromskl2onnximportto_onnxfromonnx_array_api.plotting.text_plotimportonnx_text_plot_treeiris=load_iris()X,y=iris.data.astype(numpy.float32),iris.targetclr=DecisionTreeRegressor(max_depth=3)clr.fit(X,y)onx=to_onnx(clr,X)res=onnx_text_plot_tree(onx.graph.node[0])print(res)

>>>

n_targets=1n_trees=1----treeid=0nX3<=np.float32(0.8)-nX3<=np.float32(1.75)-nX2<=np.float32(4.85)-f0:2+f0:1.67+nX2<=np.float32(4.95)-f0:1.67+f0:1.02+f0:0
onnx_array_api.plotting.text_plot.onnx_text_plot_io(model,verbose=False,att_display=None)[source]

Displays information about input and output types.

Parameters:
  • model – ONNX graph

  • verbose – display debugging information

Returns:

str

An ONNX graph is printed the following way:

<<<

importnumpyfromsklearn.clusterimportKMeansfromskl2onnximportto_onnxfromonnx_array_api.plotting.text_plotimportonnx_text_plot_iox=numpy.random.randn(10,3)y=numpy.random.randn(10)model=KMeans(3)model.fit(x,y)onx=to_onnx(model,x.astype(numpy.float32),target_opset=15)text=onnx_text_plot_io(onx,verbose=False)print(text)

>>>

opset:domain=''version=15input:name='X'type=dtype('float32')shape=['',3]init:name='Ad_Addcst'type=float32shape=(3,)init:name='Ge_Gemmcst'type=float32shape=(3,3)init:name='Mu_Mulcst'type=float32shape=(1,)output:name='label'type=dtype('int64')shape=['']output:name='scores'type=dtype('float32')shape=['',3]
onnx_array_api.plotting.text_plot.onnx_simple_text_plot(model,verbose=False,att_display=None,add_links=False,recursive=False,functions=True,raise_exc=True,sub_graphs_names=None,level=1,indent=True)[source]

Displays an ONNX graph into text.

Parameters:
  • model – ONNX graph

  • verbose – display debugging information

  • att_display – list of attributes to display; if None, a default list is used

  • add_links – displays links on the right side

  • recursive – display subgraphs as well

  • functions – display functions as well

  • raise_exc – raises an exception if the model is not valid, otherwise tries to continue

  • sub_graphs_names – list of sub-graphs names

  • level – sub-graph level

  • indent – use indentation or not

Returns:

str

An ONNX graph is printed the following way:

<<<

importnumpyfromsklearn.clusterimportKMeansfromskl2onnximportto_onnxfromonnx_array_api.plotting.text_plotimportonnx_simple_text_plotx=numpy.random.randn(10,3)y=numpy.random.randn(10)model=KMeans(3)model.fit(x,y)onx=to_onnx(model,x.astype(numpy.float32),target_opset=15)text=onnx_simple_text_plot(onx,verbose=False)print(text)

>>>

opset:domain=''version=15input:name='X'type=dtype('float32')shape=['',3]init:name='Ad_Addcst'type=float32shape=(3,)--array([1.994,1.628,3.194],dtype=float32)init:name='Ge_Gemmcst'type=float32shape=(3,3)init:name='Mu_Mulcst'type=float32shape=(1,)--array([0.],dtype=float32)ReduceSumSquare(X,axes=[1],keepdims=1)->Re_reduced0Mul(Re_reduced0,Mu_Mulcst)->Mu_C0Gemm(X,Ge_Gemmcst,Mu_C0,alpha=-2.00,transB=1)->Ge_Y0Add(Re_reduced0,Ge_Y0)->Ad_C01Add(Ad_Addcst,Ad_C01)->Ad_C0ArgMin(Ad_C0,axis=1,keepdims=0)->labelSqrt(Ad_C0)->scoresoutput:name='label'type=dtype('int64')shape=['']output:name='scores'type=dtype('float32')shape=['',3]

The same graph with links.

<<<

importnumpyfromsklearn.clusterimportKMeansfromskl2onnximportto_onnxfromonnx_array_api.plotting.text_plotimportonnx_simple_text_plotx=numpy.random.randn(10,3)y=numpy.random.randn(10)model=KMeans(3)model.fit(x,y)onx=to_onnx(model,x.astype(numpy.float32),target_opset=15)text=onnx_simple_text_plot(onx,verbose=False,add_links=True)print(text)

>>>

opset:domain=''version=15input:name='X'type=dtype('float32')shape=['',3]-------------------------------------------+-+init:name='Ad_Addcst'type=float32shape=(3,)--array([0.959,0.651,6.432],dtype=float32)|-|-----------+init:name='Ge_Gemmcst'type=float32shape=(3,3)------------------------------------+|||init:name='Mu_Mulcst'type=float32shape=(1,)--array([0.],dtype=float32)-----+||||ReduceSumSquare(X,axes=[1],keepdims=1)->Re_reduced0<--+----------------------|-+-|--------+||Mul(Re_reduced0,Mu_Mulcst)->Mu_C0<-------------------+----------------------+||||Gemm(X,Ge_Gemmcst,Mu_C0,alpha=-2.00,transB=1)->Ge_Y0<-------------------|-+----------+|Add(Re_reduced0,Ge_Y0)->Ad_C01<-----------------------------------------------+|Add(Ad_Addcst,Ad_C01)->Ad_C0----------------+-+------------------------------------------------------+ArgMin(Ad_C0,axis=1,keepdims=0)->label<--+-|--+Sqrt(Ad_C0)->scores<-------------------------+--|-----+output:name='label'type=dtype('int64')shape=['']<----+|output:name='scores'type=dtype('float32')shape=['',3]<----+

Visually, it looks like the following:

digraph{  size=7;  ranksep=0.25;  nodesep=0.05;  orientation=portrait;  X [shape=box color=red label="X\nTensorProto.FLOAT\nshape=['', 3]" fontsize=10];  label [shape=box color=green label="label\nTensorProto.INT64\nshape=['']" fontsize=10];  scores [shape=box color=green label="scores\nTensorProto.FLOAT\nshape=['', 3]" fontsize=10];  Ad_Addcst [shape=box label="Ad_Addcst\nfloat32((3,))\n[3.671 1.108 0.756]" fontsize=10];  Ge_Gemmcst [shape=box label="Ge_Gemmcst\nfloat32((3, 3))\n[[ 1.414 -0.672  1.106]\n [ 0.018  0.92   0.511]\n [..." fontsize=10];  Mu_Mulcst [shape=box label="Mu_Mulcst\nfloat32((1,))\n[0.]" fontsize=10];  Re_reduced0 [shape=box label="Re_reduced0" fontsize=10];  Re_ReduceSumSquare [shape=box style="filled,rounded" color=orange label="ReduceSumSquare\naxes=[1]\nkeepdims=1" fontsize=10];  X -> Re_ReduceSumSquare;  Re_ReduceSumSquare -> Re_reduced0;  Mu_C0 [shape=box label="Mu_C0" fontsize=10];  Mu_Mul [shape=box style="filled,rounded" color=orange label="Mul" fontsize=10];  Re_reduced0 -> Mu_Mul;  Mu_Mulcst -> Mu_Mul;  Mu_Mul -> Mu_C0;  Ge_Y0 [shape=box label="Ge_Y0" fontsize=10];  Ge_Gemm [shape=box style="filled,rounded" color=orange label="Gemm\nalpha=-2.0\ntransB=1" fontsize=10];  X -> Ge_Gemm;  Ge_Gemmcst -> Ge_Gemm;  Mu_C0 -> Ge_Gemm;  Ge_Gemm -> Ge_Y0;  Ad_C01 [shape=box label="Ad_C01" fontsize=10];  Ad_Add [shape=box style="filled,rounded" color=orange label="Add" fontsize=10];  Re_reduced0 -> Ad_Add;  Ge_Y0 -> Ad_Add;  Ad_Add -> Ad_C01;  Ad_C0 [shape=box label="Ad_C0" fontsize=10];  Ad_Add1 [shape=box style="filled,rounded" color=orange label="Add" fontsize=10];  Ad_Addcst -> Ad_Add1;  Ad_C01 -> Ad_Add1;  Ad_Add1 -> Ad_C0;  Ar_ArgMin [shape=box style="filled,rounded" color=orange label="ArgMin\naxis=1\nkeepdims=0" fontsize=10];  Ad_C0 -> Ar_ArgMin;  Ar_ArgMin -> label;  Sq_Sqrt [shape=box style="filled,rounded" color=orange label="Sqrt" fontsize=10];  Ad_C0 -> Sq_Sqrt;  Sq_Sqrt -> scores;}
On this page

[8]ページ先頭

©2009-2025 Movatter.jp