Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit53506d1

Browse files
authored
First draft to export to GraphBuilder (#83)
* export to builder* doc* fix unit test* fix order* fix initializer* fix ut* fix opset
1 parenta54de21 commit53506d1

File tree

8 files changed

+354
-5
lines changed

8 files changed

+354
-5
lines changed

‎CHANGELOGS.rst‎

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
Change Logs
22
===========
33

4-
0.2.0
4+
0.3.0
55
+++++
66

7+
*:pr:`79`: first draft to export to GraphBuilder
78
*:pr:`77`: supports ConcatOfShape and Slice with the light API
9+
10+
0.2.0
11+
+++++
12+
813
*:pr:`76`,:pr:`79`: add a mode to compare models without execution
914
*:pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
1015
*:pr:`71`: adds tools to compare two onnx graphs

‎_unittests/ut_translate_api/test_translate.py‎

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,5 +221,4 @@ def test_aionnxml(self):
221221

222222

223223
if__name__=="__main__":
224-
TestTranslate().test_export_if()
225224
unittest.main(verbosity=2)
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
importunittest
2+
fromtextwrapimportdedent
3+
importnumpyasnp
4+
fromonnximportModelProto,TensorProto
5+
fromonnx.checkerimportcheck_model
6+
fromonnx.defsimportonnx_opset_version
7+
fromonnx.referenceimportReferenceEvaluator
8+
fromonnx_array_api.ext_test_caseimportExtTestCase
9+
fromonnx_array_api.light_apiimportstart
10+
fromonnx_array_api.graph_apiimportGraphBuilder
11+
fromonnx_array_api.translate_apiimporttranslate
12+
13+
14+
OPSET_API=min(19,onnx_opset_version()-1)
15+
16+
17+
classTestTranslateBuilder(ExtTestCase):
18+
defsetUp(self):
19+
self.maxDiff=None
20+
21+
deftest_exp(self):
22+
onx=start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
23+
self.assertIsInstance(onx,ModelProto)
24+
self.assertIn("Exp",str(onx))
25+
ref=ReferenceEvaluator(onx)
26+
a=np.arange(10).astype(np.float32)
27+
got=ref.run(None, {"X":a})[0]
28+
self.assertEqualArray(np.exp(a),got)
29+
30+
code=translate(onx,api="builder")
31+
expected=dedent(
32+
"""
33+
def light_api(
34+
op: "GraphBuilder",
35+
X: "FLOAT[]",
36+
):
37+
Y = op.Exp(X)
38+
op.Identity(Y, outputs=["Y"])
39+
return Y
40+
41+
g = GraphBuilder({'': 19})
42+
g.make_tensor_input("X", TensorProto.FLOAT, ())
43+
light_api(g.op, "X")
44+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
45+
model = g.to_onnx()
46+
"""
47+
).strip("\n")
48+
self.assertEqual(expected,code.strip("\n"))
49+
50+
deflight_api(
51+
op:"GraphBuilder",
52+
X:"FLOAT[]",# noqa: F722
53+
):
54+
Y=op.Exp(X)
55+
op.Identity(Y,outputs=["Y"])
56+
returnY
57+
58+
g2=GraphBuilder({"":19})
59+
g2.make_tensor_input("X",TensorProto.FLOAT, ("A",))
60+
light_api(g2.op,"X")
61+
g2.make_tensor_output("Y",TensorProto.FLOAT, ("A",))
62+
onx2=g2.to_onnx()
63+
64+
ref=ReferenceEvaluator(onx2)
65+
a=np.arange(10).astype(np.float32)
66+
got=ref.run(None, {"X":a})[0]
67+
self.assertEqualArray(np.exp(a),got)
68+
69+
deftest_zdoc(self):
70+
onx= (
71+
start(opset=19)
72+
.vin("X")
73+
.reshape((-1,1))
74+
.Transpose(perm=[1,0])
75+
.rename("Y")
76+
.vout()
77+
.to_onnx()
78+
)
79+
code=translate(onx,api="builder")
80+
expected=dedent(
81+
"""
82+
def light_api(
83+
op: "GraphBuilder",
84+
X: "FLOAT[]",
85+
):
86+
r = np.array([-1, 1], dtype=np.int64)
87+
r0_0 = op.Reshape(X, r)
88+
Y = op.Transpose(r0_0, perm=[1, 0])
89+
op.Identity(Y, outputs=["Y"])
90+
return Y
91+
92+
g = GraphBuilder({'': 19})
93+
g.make_tensor_input("X", TensorProto.FLOAT, ())
94+
light_api(g.op, "X")
95+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
96+
model = g.to_onnx()
97+
"""
98+
).strip("\n")
99+
self.maxDiff=None
100+
self.assertEqual(expected,code.strip("\n"))
101+
102+
deflight_api(
103+
op:"GraphBuilder",
104+
X:"FLOAT[]",# noqa: F722
105+
):
106+
r=np.array([-1,1],dtype=np.int64)
107+
r0_0=op.Reshape(X,r)
108+
Y=op.Transpose(r0_0,perm=[1,0])
109+
op.Identity(Y,outputs=["Y"])
110+
returnY
111+
112+
g=GraphBuilder({"":21})
113+
X=g.make_tensor_input("X",TensorProto.FLOAT, ())
114+
light_api(g.op,X)
115+
g.make_tensor_output("Y",TensorProto.FLOAT, ())
116+
model=g.to_onnx()
117+
self.assertNotEmpty(model)
118+
check_model(model)
119+
120+
121+
if__name__=="__main__":
122+
unittest.main(verbosity=2)

‎onnx_array_api/graph_api/graph_builder.py‎

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,18 @@ def __getattr__(self, name):
119119
exceptAttributeErrorase:
120120
raiseAttributeError(f"Unable to access attribute{name!r}.")frome
121121

122+
defInitializer(
123+
self,init:Union[TensorProto,np.ndarray],name:Optional[str]=None
124+
)->str:
125+
"""
126+
Creates an initializer.
127+
128+
:param init: value
129+
:param name: name if value is not a TensorProto
130+
:return: its name
131+
"""
132+
returnself.builder.make_initializer(init,name=name,exists=True)
133+
122134
defmake_node(
123135
self,
124136
op_type:str,

‎onnx_array_api/translate_api/__init__.py‎

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
fromonnximportModelProto
22
from .translateimportTranslater
33
from .inner_emitterimportInnerEmitter
4+
from .builder_emitterimportBuilderEmitter
45

56

67
deftranslate(proto:ModelProto,single_line:bool=False,api:str="light")->str:
@@ -14,7 +15,8 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
1415
default is `"light"` and this is handle by class
1516
:class:`onnx_array_api.translate_api.light_emitter.LightEmitter`,
1617
another value is `"onnx"` which is the inner API implemented
17-
in onnx package.
18+
in onnx package, `"builder"` follows the syntax for the
19+
class :class:`onnx_array_api.graph_api.GraphBuilder`
1820
:return: code
1921
2022
.. runpython::
@@ -35,7 +37,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
3537
code = translate(onx)
3638
print(code)
3739
38-
The inner API from onnxpackahe is also available.
40+
The inner API from onnxpackage is also available.
3941
4042
.. runpython::
4143
:showcode:
@@ -54,11 +56,35 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
5456
)
5557
code = translate(onx, api="onnx")
5658
print(code)
59+
60+
The :class:`GraphBuilder
61+
<onnx_array_api.graph_api.GraphBuilder>` API returns this:
62+
63+
.. runpython::
64+
:showcode:
65+
66+
from onnx_array_api.light_api import start
67+
from onnx_array_api.translate_api import translate
68+
69+
onx = (
70+
start()
71+
.vin("X")
72+
.reshape((-1, 1))
73+
.Transpose(perm=[1, 0])
74+
.rename("Y")
75+
.vout()
76+
.to_onnx()
77+
)
78+
code = translate(onx, api="builder")
79+
print(code)
5780
"""
5881
ifapi=="light":
5982
tr=Translater(proto)
6083
returntr.export(single_line=single_line,as_str=True)
6184
ifapi=="onnx":
6285
tr=Translater(proto,emitter=InnerEmitter())
6386
returntr.export(as_str=True)
87+
ifapi=="builder":
88+
tr=Translater(proto,emitter=BuilderEmitter())
89+
returntr.export(as_str=True)
6490
raiseValueError(f"Unexpected value{api!r} for api.")

‎onnx_array_api/translate_api/base_emitter.py‎

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ class EventType(IntEnum):
2121
FUNCTION_OUTPUT=12
2222
FUNCTION_ATTRIBUTES=13
2323
TO_ONNX_FUNCTION=14
24+
BEGIN_SIGNATURE=15
25+
END_SIGNATURE=16
26+
BEGIN_RETURN=17
27+
END_RETURN=18
2428

2529
@classmethod
2630
defto_str(cls,self)->str:
@@ -84,6 +88,18 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
8488
ifevent==EventType.FUNCTION_ATTRIBUTES:
8589
returnself._emit_function_attributes(**kwargs)
8690

91+
ifevent==EventType.BEGIN_SIGNATURE:
92+
returnself._emit_begin_signature(**kwargs)
93+
94+
ifevent==EventType.END_SIGNATURE:
95+
returnself._emit_end_signature(**kwargs)
96+
97+
ifevent==EventType.BEGIN_RETURN:
98+
returnself._emit_begin_return(**kwargs)
99+
100+
ifevent==EventType.END_RETURN:
101+
returnself._emit_end_return(**kwargs)
102+
87103
raiseValueError(f"Unexpected event{EventType.to_str(event)}.")
88104

89105
defrender_attribute_value(self,value:Any)->Tuple[List[str],str]:
@@ -222,3 +238,15 @@ def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
222238
raiseNotImplementedError(
223239
f"Method{inspect.currentframe().f_code.co_name!r} was not overloaded."
224240
)
241+
242+
def_emit_begin_signature(self,**kwargs:Dict[str,Any])->List[str]:
243+
return []
244+
245+
def_emit_end_signature(self,**kwargs:Dict[str,Any])->List[str]:
246+
return []
247+
248+
def_emit_begin_return(self,**kwargs:Dict[str,Any])->List[str]:
249+
return []
250+
251+
def_emit_end_return(self,**kwargs:Dict[str,Any])->List[str]:
252+
return []

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp