Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit1b49fee

Browse files
committed
export to builder
1 parenta54de21 commit1b49fee

File tree

6 files changed

+278
-12
lines changed

6 files changed

+278
-12
lines changed

‎_unittests/ut_translate_api/test_translate.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,5 +221,4 @@ def test_aionnxml(self):
221221

222222

223223
if__name__=="__main__":
224-
TestTranslate().test_export_if()
225224
unittest.main(verbosity=2)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
importunittest
2+
fromtextwrapimportdedent
3+
importnumpyasnp
4+
fromonnximportModelProto,TensorProto
5+
fromonnx.defsimportonnx_opset_version
6+
fromonnx.referenceimportReferenceEvaluator
7+
fromonnx_array_api.ext_test_caseimportExtTestCase
8+
fromonnx_array_api.light_apiimportstart
9+
fromonnx_array_api.graph_apiimportGraphBuilder
10+
fromonnx_array_api.translate_apiimporttranslate
11+
12+
13+
OPSET_API=min(19,onnx_opset_version()-1)
14+
15+
16+
classTestTranslateBuilder(ExtTestCase):
17+
defsetUp(self):
18+
self.maxDiff=None
19+
20+
deftest_exp(self):
21+
onx=start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
22+
self.assertIsInstance(onx,ModelProto)
23+
self.assertIn("Exp",str(onx))
24+
ref=ReferenceEvaluator(onx)
25+
a=np.arange(10).astype(np.float32)
26+
got=ref.run(None, {"X":a})[0]
27+
self.assertEqualArray(np.exp(a),got)
28+
29+
code=translate(onx,api="builder")
30+
expected=dedent(
31+
"""
32+
def light_api(
33+
op: "GraphBuilder",
34+
X: "FLOAT[]",
35+
):
36+
Y = op.Exp(X)
37+
op.Identity(Y, outputs=["Y"])
38+
return Y
39+
40+
g = GraphBuilder({'': 19})
41+
g.make_tensor_input("X", TensorProto.FLOAT, ())
42+
light_api(g.op, X)
43+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
44+
model = g.to_onnx()
45+
"""
46+
).strip("\n")
47+
self.assertEqual(expected,code.strip("\n"))
48+
49+
deflight_api(
50+
op:"GraphBuilder",
51+
X:"FLOAT[]",# noqa: F722
52+
):
53+
Y=op.Exp(X)
54+
op.Identity(Y,outputs=["Y"])
55+
returnY
56+
57+
g2=GraphBuilder({"":19})
58+
g2.make_tensor_input("X",TensorProto.FLOAT, ("A",))
59+
light_api(g2.op,"X")
60+
g2.make_tensor_output("Y",TensorProto.FLOAT, ("A",))
61+
onx2=g2.to_onnx()
62+
63+
ref=ReferenceEvaluator(onx2)
64+
a=np.arange(10).astype(np.float32)
65+
got=ref.run(None, {"X":a})[0]
66+
self.assertEqualArray(np.exp(a),got)
67+
68+
69+
if__name__=="__main__":
70+
unittest.main(verbosity=2)

‎onnx_array_api/translate_api/__init__.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
fromonnximportModelProto
22
from .translateimportTranslater
33
from .inner_emitterimportInnerEmitter
4+
from .builder_emitterimportBuilderEmitter
45

56

67
deftranslate(proto:ModelProto,single_line:bool=False,api:str="light")->str:
@@ -14,7 +15,8 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
1415
default is `"light"` and this is handle by class
1516
:class:`onnx_array_api.translate_api.light_emitter.LightEmitter`,
1617
another value is `"onnx"` which is the inner API implemented
17-
in onnx package.
18+
in onnx package, `"builder"` follows the syntax for the
19+
class :class:`onnx_array_api.graph_api.GraphBuilder`
1820
:return: code
1921
2022
.. runpython::
@@ -35,7 +37,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
3537
code = translate(onx)
3638
print(code)
3739
38-
The inner API from onnxpackahe is also available.
40+
The inner API from onnxpackage is also available.
3941
4042
.. runpython::
4143
:showcode:
@@ -54,11 +56,35 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
5456
)
5557
code = translate(onx, api="onnx")
5658
print(code)
59+
60+
The :class:`GraphBuilder
61+
<onnx_array_api.graph_api.GraphBuilder>` API returns this:
62+
63+
.. runpython::
64+
:showcode:
65+
66+
from onnx_array_api.light_api import start
67+
from onnx_array_api.translate_api import translate
68+
69+
onx = (
70+
start()
71+
.vin("X")
72+
.reshape((-1, 1))
73+
.Transpose(perm=[1, 0])
74+
.rename("Y")
75+
.vout()
76+
.to_onnx()
77+
)
78+
code = translate(onx, api="builder")
79+
print(code)
5780
"""
5881
ifapi=="light":
5982
tr=Translater(proto)
6083
returntr.export(single_line=single_line,as_str=True)
6184
ifapi=="onnx":
6285
tr=Translater(proto,emitter=InnerEmitter())
6386
returntr.export(as_str=True)
87+
ifapi=="builder":
88+
tr=Translater(proto,emitter=BuilderEmitter())
89+
returntr.export(as_str=True)
6490
raiseValueError(f"Unexpected value{api!r} for api.")

‎onnx_array_api/translate_api/base_emitter.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ class EventType(IntEnum):
2121
FUNCTION_OUTPUT=12
2222
FUNCTION_ATTRIBUTES=13
2323
TO_ONNX_FUNCTION=14
24+
BEGIN_SIGNATURE=15
25+
END_SIGNATURE=16
26+
BEGIN_RETURN=17
27+
END_RETURN=18
2428

2529
@classmethod
2630
defto_str(cls,self)->str:
@@ -84,6 +88,18 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
8488
ifevent==EventType.FUNCTION_ATTRIBUTES:
8589
returnself._emit_function_attributes(**kwargs)
8690

91+
ifevent==EventType.BEGIN_SIGNATURE:
92+
returnself._emit_begin_signature(**kwargs)
93+
94+
ifevent==EventType.END_SIGNATURE:
95+
returnself._emit_end_signature(**kwargs)
96+
97+
ifevent==EventType.BEGIN_RETURN:
98+
returnself._emit_begin_return(**kwargs)
99+
100+
ifevent==EventType.END_RETURN:
101+
returnself._emit_end_return(**kwargs)
102+
87103
raiseValueError(f"Unexpected event{EventType.to_str(event)}.")
88104

89105
defrender_attribute_value(self,value:Any)->Tuple[List[str],str]:
@@ -222,3 +238,15 @@ def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
222238
raiseNotImplementedError(
223239
f"Method{inspect.currentframe().f_code.co_name!r} was not overloaded."
224240
)
241+
242+
def_emit_begin_signature(self,**kwargs:Dict[str,Any])->List[str]:
243+
return []
244+
245+
def_emit_end_signature(self,**kwargs:Dict[str,Any])->List[str]:
246+
return []
247+
248+
def_emit_begin_return(self,**kwargs:Dict[str,Any])->List[str]:
249+
return []
250+
251+
def_emit_end_return(self,**kwargs:Dict[str,Any])->List[str]:
252+
return []
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
fromtypingimportAny,Dict,List
2+
fromonnximportTensorProto
3+
from .base_emitterimportBaseEmitter
4+
5+
_types= {
6+
TensorProto.FLOAT:"FLOAT",
7+
TensorProto.FLOAT16:"FLOAT16",
8+
TensorProto.INT64:"INT64",
9+
TensorProto.INT32:"INT32",
10+
}
11+
12+
13+
def_itype_to_string(itype:int)->str:
14+
return_types[itype]
15+
16+
17+
classBuilderEmitter(BaseEmitter):
18+
"""
19+
Converts event into proper code.
20+
"""
21+
22+
defjoin(self,rows:List[str],single_line:bool=False)->str:
23+
"Join the rows"
24+
assert (
25+
notsingle_line
26+
),f"The emitter{type(self)} does not work with single_line=True."
27+
return"\n".join(rows)
28+
29+
def_emit_start(self,**kwargs:Dict[str,Any])->List[str]:
30+
self.opsets=kwargs.get("opsets", {})
31+
return []
32+
33+
def_emit_to_onnx_model(self,**kwargs:Dict[str,Any])->List[str]:
34+
inps=", ".join(["g.op",*self.inputs])
35+
inputs= []
36+
forinp,stype,shapeinself.inputs_full_:
37+
inputs.append(f'g.make_tensor_input("{inp}", TensorProto.{stype},{shape})')
38+
outputs= []
39+
forinp,stype,shapeinself.outputs_full_:
40+
outputs.append(
41+
f'g.make_tensor_output("{inp}", TensorProto.{stype},{shape})'
42+
)
43+
rows= [
44+
"",
45+
f"g = GraphBuilder({self.opsets})",
46+
*inputs,
47+
f"{self.name}({inps})",
48+
*outputs,
49+
"model = g.to_onnx()",
50+
]
51+
returnrows
52+
53+
def_emit_begin_graph(self,**kwargs:Dict[str,Any])->List[str]:
54+
self.inputs= []
55+
self.inputs_full= []
56+
self.outputs= []
57+
self.inits= []
58+
self.inputs_full_= []
59+
self.outputs_full_= []
60+
self.name=kwargs.get("name","make_graph")
61+
return []
62+
63+
def_emit_end_graph(self,**kwargs:Dict[str,Any])->List[str]:
64+
return []
65+
66+
def_emit_initializer(self,**kwargs:Dict[str,Any])->List[str]:
67+
assertFalse,f"not implemented yet with{kwargs}"
68+
69+
def_emit_input(self,**kwargs:Dict[str,Any])->List[str]:
70+
name=kwargs["name"]
71+
itype=kwargs.get("elem_type",0)
72+
shape=kwargs.get("shape",None)
73+
ifitype==0:
74+
inp="X"
75+
else:
76+
ifshapeisNone:
77+
inp=f'X: "{_itype_to_string(itype)}"'
78+
else:
79+
inp=f'X: "{_itype_to_string(itype)}[{", ".join(map(str,shape))}]"'
80+
self.inputs_full.append(inp)
81+
self.inputs.append(name)
82+
self.inputs_full_.append((name,_itype_to_string(itype),shape))
83+
return []
84+
85+
def_emit_begin_signature(self,**kwargs:Dict[str,Any])->List[str]:
86+
return []
87+
88+
def_emit_end_signature(self,**kwargs:Dict[str,Any])->List[str]:
89+
rows= ["",f"def{self.name}(",' op: "GraphBuilder",']
90+
foriinself.inputs_full:
91+
rows.append(f"{i},")
92+
rows.append("):")
93+
returnrows
94+
95+
def_emit_begin_return(self,**kwargs:Dict[str,Any])->List[str]:
96+
return []
97+
98+
def_emit_end_return(self,**kwargs:Dict[str,Any])->List[str]:
99+
outs=", ".join(self.outputs)
100+
return [f" return{outs}"]
101+
102+
def_emit_output(self,**kwargs:Dict[str,Any])->List[str]:
103+
name=kwargs["name"]
104+
itype=kwargs.get("elem_type",0)
105+
shape=kwargs.get("shape",None)
106+
self.outputs.append(name)
107+
self.outputs_full_.append((name,_itype_to_string(itype),shape))
108+
return [f' op.Identity({name}, outputs=["{name}"])']
109+
110+
def_emit_node(self,**kwargs:Dict[str,Any])->List[str]:
111+
op_type=kwargs["op_type"]
112+
inputs=kwargs["inputs"]
113+
outputs=kwargs["outputs"]
114+
ifkwargs.get("domain","")!="":
115+
domain=kwargs["domain"]
116+
op_type=f"{domain}.{op_type}"
117+
atts=kwargs.get("atts", {})
118+
args= []
119+
fork,vinatts.items():
120+
before,vatt=self.render_attribute_value(v)
121+
ifbefore:
122+
raiseNotImplementedError("Graph attribute not supported yet.")
123+
args.append(f"{k}={vatt}")
124+
125+
outs=", ".join(outputs)
126+
inps=", ".join(inputs)
127+
ifargs:
128+
sargs=", ".join(args)
129+
row=f"{outs} = op.{op_type}({inps},{sargs})"
130+
else:
131+
row=f"{outs} = op.{op_type}({inps})"
132+
return [row]

‎onnx_array_api/translate_api/translate.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,12 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
7676
)
7777
)
7878
else:
79-
rows.extend(self.emitter(EventType.BEGIN_GRAPH))
80-
81-
foriininitializers:
8279
rows.extend(
83-
self.emitter(
84-
EventType.INITIALIZER,
85-
name=i.name,
86-
init=i,
87-
value=to_array_extended(i),
88-
)
80+
self.emitter(EventType.BEGIN_GRAPH,name=self.proto_.graph.name)
8981
)
9082

83+
rows.extend(self.emitter(EventType.BEGIN_SIGNATURE))
84+
9185
foriininputs:
9286
ifis_function:
9387
rows.extend(self.emitter(EventType.FUNCTION_INPUT,name=i))
@@ -109,6 +103,18 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
109103
self.emitter(EventType.FUNCTION_ATTRIBUTES,attributes=list(attributes))
110104
)
111105

106+
rows.extend(self.emitter(EventType.END_SIGNATURE))
107+
108+
foriininitializers:
109+
rows.extend(
110+
self.emitter(
111+
EventType.INITIALIZER,
112+
name=i.name,
113+
init=i,
114+
value=to_array_extended(i),
115+
)
116+
)
117+
112118
fornodeinnodes:
113119
atts=self.extract_attributes(node)
114120
rows.extend(
@@ -122,6 +128,8 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
122128
)
123129
)
124130

131+
rows.extend(self.emitter(EventType.BEGIN_RETURN))
132+
125133
foroinoutputs:
126134
ifis_function:
127135
rows.extend(self.emitter(EventType.FUNCTION_OUTPUT,name=o))
@@ -137,6 +145,9 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
137145
),
138146
)
139147
)
148+
149+
rows.extend(self.emitter(EventType.END_RETURN))
150+
140151
ifisinstance(self.proto_, (GraphProto,FunctionProto)):
141152
name=self.proto_.name
142153
else:

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp