Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitc74899f

Browse files
committed
Adds graph API to the tutorial
1 parent954b959 commitc74899f

File tree

5 files changed

+119
-37
lines changed

5 files changed

+119
-37
lines changed

‎_doc/tutorial/graph_api.rst

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
.. _l-graph-api:
2+
3+
=================================
4+
GraphBuilder: common API for ONNX
5+
=================================
6+
7+
This is a very common way to build ONNX graph. There are some
8+
annoying steps while building an ONNX graph. The first one is to
9+
give unique names to every intermediate result in the graph. The second
10+
is the conversion from numpy arrays to onnx tensors. A *graph builder*,
11+
here implemented by class
12+
:class:`GraphBuilder <onnx_array_api.graph_api.GraphBuilder>`
13+
usually makes these two frequent tasks easier.
14+
15+
..runpython::
16+
:showcode:
17+
18+
import numpy as np
19+
from onnx_array_api.graph_api import GraphBuilder
20+
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
21+
22+
g = GraphBuilder()
23+
g.make_tensor_input("X", np.float32, (None, None))
24+
g.make_tensor_input("Y", np.float32, (None, None))
25+
r1 = g.make_node("Sub", ["X", "Y"]) # the name given to the output is given by the class,
26+
# it ensures the name is unique
27+
init = g.make_initializer(np.array([2], dtype=np.int64)) # the class automatically
28+
# converts the array to a tensor
29+
r2 = g.make_node("Pow", [r1, init])
30+
g.make_node("ReduceSum", [r2], outputs=["Z"]) # the output name is given because
31+
# the user wants to choose the name
32+
g.make_tensor_output("Z", np.float32, (None, None))
33+
34+
onx = g.to_onnx() # final conversion to onnx
35+
36+
print(onnx_simple_text_plot(onx))
37+
38+
A more simple versions of the same code to produce the same graph.
39+
40+
..runpython::
41+
:showcode:
42+
43+
import numpy as np
44+
from onnx_array_api.graph_api import GraphBuilder
45+
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
46+
47+
g = GraphBuilder()
48+
g.make_tensor_input("X", np.float32, (None, None))
49+
g.make_tensor_input("Y", np.float32, (None, None))
50+
r1 = g.op.Sub("X", "Y") # the method name indicates which operator to use,
51+
# this can be used when there is no ambiguity about the
52+
# number of outputs
53+
r2 = g.op.Pow(r1, np.array([2], dtype=np.int64))
54+
g.op.ReduceSum(r2, outputs=["Z"]) # the still wants the user to specify the name
55+
g.make_tensor_output("Z", np.float32, (None, None))
56+
57+
onx = g.to_onnx()
58+
59+
print(onnx_simple_text_plot(onx))

‎_doc/tutorial/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Tutorial
77
:maxdepth:1
88

99
onnx_api
10+
graph_api
1011
light_api
1112
numpy_api
1213
benchmarks

‎_doc/tutorial/onnx_api.rst

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -584,37 +584,31 @@ The second part modifies it.
584584
585585
onnx.save(gs.export_onnx(graph),"modified.onnx")
586586
587-
numpy API for onnx
588-
++++++++++++++++++
587+
Graph Builder API
588+
+++++++++++++++++
589589

590-
See:ref:`l-numpy-api-onnx`. This API was introduced to create graphs
591-
by using numpy API. If a function is defined only with numpy,
592-
it should be possible to use the exact same code to create the
593-
corresponding onnx graph. That's what this API tries to achieve.
594-
It works with the exception of control flow. In that case, the function
595-
produces different onnx graphs depending on the execution path.
590+
See:ref:`l-graph-api`. This API is very similar to what *skl2onnx* implements.
591+
It is still about adding nodes to a graph but some tasks are automated such as
592+
naming the results or converting constants to onnx classes.
596593

597594
..runpython::
598595
:showcode:
599596

600597
import numpy as np
601-
from onnx_array_api.npximportjit_onnx
598+
from onnx_array_api.graph_apiimportGraphBuilder
602599
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
603600

604-
def l2_loss(x, y):
605-
return ((x - y) ** 2).sum(keepdims=1)
606-
607-
jitted_myloss =jit_onnx(l2_loss)
608-
dummy = np.array([0], dtype=np.float32)
609-
610-
# The function is executed. Only then a onnx graph is created.
611-
# One is created depending on the input type.
612-
jitted_myloss(dummy, dummy)
601+
g = GraphBuilder()
602+
g.make_tensor_input("X", np.float32, (None, None))
603+
g.make_tensor_input("Y", np.float32, (None, None))
604+
r1 =g.op.Sub("X", "Y")
605+
r2 =g.op.Pow(r1,np.array([2], dtype=np.int64))
606+
g.op.ReduceSum(r2, outputs=["Z"])
607+
g.make_tensor_output("Z", np.float32, (None, None))
608+
609+
onx = g.to_onnx()
613610

614-
# get_onnx only works if it was executed once or at least with
615-
# the same input type
616-
model = jitted_myloss.get_onnx()
617-
print(onnx_simple_text_plot(model))
611+
print(onnx_simple_text_plot(onx))
618612

619613
Light API
620614
+++++++++
@@ -647,3 +641,35 @@ There is no eager mode.
647641
)
648642

649643
print(onnx_simple_text_plot(model))
644+
645+
numpy API for onnx
646+
++++++++++++++++++
647+
648+
See:ref:`l-numpy-api-onnx`. This API was introduced to create graphs
649+
by using numpy API. If a function is defined only with numpy,
650+
it should be possible to use the exact same code to create the
651+
corresponding onnx graph. That's what this API tries to achieve.
652+
It works with the exception of control flow. In that case, the function
653+
produces different onnx graphs depending on the execution path.
654+
655+
..runpython::
656+
:showcode:
657+
658+
import numpy as np
659+
from onnx_array_api.npx import jit_onnx
660+
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
661+
662+
def l2_loss(x, y):
663+
return ((x - y) ** 2).sum(keepdims=1)
664+
665+
jitted_myloss = jit_onnx(l2_loss)
666+
dummy = np.array([0], dtype=np.float32)
667+
668+
# The function is executed. Only then a onnx graph is created.
669+
# One is created depending on the input type.
670+
jitted_myloss(dummy, dummy)
671+
672+
# get_onnx only works if it was executed once or at least with
673+
# the same input type
674+
model = jitted_myloss.get_onnx()
675+
print(onnx_simple_text_plot(model))

‎onnx_array_api/graph_api/graph_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ class Opset:
5050
"Mul":1,
5151
"Log":1,
5252
"Or":1,
53+
"Pow":1,
5354
"Relu":1,
55+
"ReduceSum":1,
5456
"Reshape":1,
5557
"Shape":1,
5658
"Slice":1,

‎onnx_array_api/plotting/text_plot.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,7 @@ def iterate(nodes, node, depth=0, true_false=""):
184184
rows.extend(r)
185185
return"\n".join(rows)
186186

187-
raiseNotImplementedError(# pragma: no cover
188-
f"Type{node.op_type!r} cannot be displayed."
189-
)
187+
raiseNotImplementedError(f"Type{node.op_type!r} cannot be displayed.")
190188

191189

192190
def_append_succ_pred(
@@ -403,7 +401,7 @@ def _find_sequence(node_name, known, done):
403401
)
404402

405403
ifnotsequences:
406-
raiseRuntimeError(# pragma: no cover
404+
raiseRuntimeError(
407405
"Unexpected empty sequence (len(possibles)=%d, "
408406
"len(done)=%d, len(nodes)=%d). This is usually due to "
409407
"a name used both as result name and node node. "
@@ -434,7 +432,7 @@ def _find_sequence(node_name, known, done):
434432
best=k
435433

436434
ifbestisNone:
437-
raiseRuntimeError(# pragma: no cover
435+
raiseRuntimeError(
438436
f"Wrong implementation (len(sequence)={len(sequences)})."
439437
)
440438
ifverbose:
@@ -453,7 +451,7 @@ def _find_sequence(node_name, known, done):
453451
known|=set(v.output)
454452

455453
iflen(new_nodes)!=len(nodes):
456-
raiseRuntimeError(# pragma: no cover
454+
raiseRuntimeError(
457455
"The returned new nodes are different. "
458456
"len(nodes=%d) != %d=len(new_nodes). done=\n%r"
459457
"\n%s\n----------\n%s"
@@ -486,7 +484,7 @@ def _find_sequence(node_name, known, done):
486484
n0s=set(n.nameforninnodes)
487485
n1s=set(n.nameforninnew_nodes)
488486
ifn0s!=n1s:
489-
raiseRuntimeError(# pragma: no cover
487+
raiseRuntimeError(
490488
"The returned new nodes are different.\n"
491489
"%r !=\n%r\ndone=\n%r"
492490
"\n----------\n%s\n----------\n%s"
@@ -758,7 +756,7 @@ def str_node(indent, node):
758756
try:
759757
val=str(to_array(att.t).tolist())
760758
exceptTypeErrorase:
761-
raiseTypeError(# pragma: no cover
759+
raiseTypeError(
762760
"Unable to display tensor type %r.\n%s"
763761
% (att.type,str(att))
764762
)frome
@@ -853,9 +851,7 @@ def str_node(indent, node):
853851
ifisinstance(att,str):
854852
rows.append(f"attribute:{att!r}")
855853
else:
856-
raiseNotImplementedError(# pragma: no cover
857-
"Not yet introduced in onnx."
858-
)
854+
raiseNotImplementedError("Not yet introduced in onnx.")
859855

860856
# initializer
861857
ifhasattr(model,"initializer"):
@@ -894,7 +890,7 @@ def str_node(indent, node):
894890

895891
try:
896892
nodes=reorder_nodes_for_display(model.node,verbose=verbose)
897-
exceptRuntimeErrorase:# pragma: no cover
893+
exceptRuntimeErrorase:
898894
ifraise_exc:
899895
raisee
900896
else:
@@ -924,9 +920,7 @@ def str_node(indent, node):
924920
indent=mi
925921
ifprevious_indentisnotNoneandindent<previous_indent:
926922
ifverbose:
927-
print(# pragma: no cover
928-
f"[onnx_simple_text_plot] break2{node.op_type}"
929-
)
923+
print(f"[onnx_simple_text_plot] break2{node.op_type}")
930924
add_break=True
931925
ifnotadd_breakandprevious_outisnotNone:
932926
ifnot (set(node.input)&previous_out):

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp