Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit664e084

Browse files
authored
Improves translation to GraphBuilder (#95)
* Improves translation to GraphBuilder* ch* fix issue* ir* urls* check
1 parent689cc6f commit664e084

File tree

6 files changed

+127
-17
lines changed

6 files changed

+127
-17
lines changed

‎.github/workflows/check-urls.yml‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,6 @@ jobs:
4242
print_all:false
4343
timeout:2
4444
retry_count# : 2
45-
exclude_urls:https://hal.archives-ouvertes.fr/hal-00990252/document
46-
exclude_patterns:https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/
45+
exclude_urls:https://hal.archives-ouvertes.fr/hal-00990252/document,https://github.com/onnx/tensorflow-onnx
46+
exclude_patterns:https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://github.com/onnx/tensorflow-onnx
4747
# force_pass : true

‎CHANGELOGS.rst‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.3.1
5+
+++++
6+
7+
*:pr:`95`: improves translation to GraphBuilder
8+
49
0.3.0
510
+++++
611

‎_unittests/ut_translate_api/test_translate_builder.py‎

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
fromonnx_array_api.ext_test_caseimportExtTestCase
99
fromonnx_array_api.light_apiimportstart
1010
fromonnx_array_api.graph_apiimportGraphBuilder
11-
fromonnx_array_api.translate_apiimporttranslate
11+
fromonnx_array_api.translate_apiimporttranslate,Translater
12+
fromonnx_array_api.translate_api.builder_emitterimportBuilderEmitter
1213

1314

1415
OPSET_API=min(19,onnx_opset_version()-1)
@@ -19,7 +20,7 @@ def setUp(self):
1920
self.maxDiff=None
2021

2122
deftest_exp(self):
22-
onx=start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
23+
onx=start(opset=19,ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
2324
self.assertIsInstance(onx,ModelProto)
2425
self.assertIn("Exp",str(onx))
2526
ref=ReferenceEvaluator(onx)
@@ -38,7 +39,7 @@ def light_api(
3839
op.Identity(Y, outputs=["Y"])
3940
return Y
4041
41-
g = GraphBuilder({'': 19})
42+
g = GraphBuilder({'': 19}, ir_version=10)
4243
g.make_tensor_input("X", TensorProto.FLOAT, ())
4344
light_api(g.op, "X")
4445
g.make_tensor_output("Y", TensorProto.FLOAT, ())
@@ -68,7 +69,7 @@ def light_api(
6869

6970
deftest_zdoc(self):
7071
onx= (
71-
start(opset=19)
72+
start(opset=19,ir_version=10)
7273
.vin("X")
7374
.reshape((-1,1))
7475
.Transpose(perm=[1,0])
@@ -89,7 +90,7 @@ def light_api(
8990
op.Identity(Y, outputs=["Y"])
9091
return Y
9192
92-
g = GraphBuilder({'': 19})
93+
g = GraphBuilder({'': 19}, ir_version=10)
9394
g.make_tensor_input("X", TensorProto.FLOAT, ())
9495
light_api(g.op, "X")
9596
g.make_tensor_output("Y", TensorProto.FLOAT, ())
@@ -117,6 +118,62 @@ def light_api(
117118
self.assertNotEmpty(model)
118119
check_model(model)
119120

121+
deftest_exp_f(self):
122+
onx=start(opset=19,ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
123+
self.assertIsInstance(onx,ModelProto)
124+
self.assertIn("Exp",str(onx))
125+
ref=ReferenceEvaluator(onx)
126+
a=np.arange(10).astype(np.float32)
127+
got=ref.run(None, {"X":a})[0]
128+
self.assertEqualArray(np.exp(a),got)
129+
130+
tr=Translater(onx,emitter=BuilderEmitter("mm"))
131+
code=tr.export(as_str=True)
132+
133+
expected=dedent(
134+
"""
135+
def light_api(
136+
op: "GraphBuilder",
137+
X: "FLOAT[]",
138+
):
139+
Y = op.Exp(X)
140+
op.Identity(Y, outputs=["Y"])
141+
return Y
142+
143+
144+
def mm() -> "ModelProto":
145+
g = GraphBuilder({'': 19}, ir_version=10)
146+
g.make_tensor_input("X", TensorProto.FLOAT, ())
147+
light_api(g.op, "X")
148+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
149+
model = g.to_onnx()
150+
return model
151+
152+
153+
model = mm()
154+
"""
155+
).strip("\n")
156+
self.assertEqual(expected,code.strip("\n"))
157+
158+
deflight_api(
159+
op:"GraphBuilder",
160+
X:"FLOAT[]",# noqa: F722
161+
):
162+
Y=op.Exp(X)
163+
op.Identity(Y,outputs=["Y"])
164+
returnY
165+
166+
g2=GraphBuilder({"":19})
167+
g2.make_tensor_input("X",TensorProto.FLOAT, ("A",))
168+
light_api(g2.op,"X")
169+
g2.make_tensor_output("Y",TensorProto.FLOAT, ("A",))
170+
onx2=g2.to_onnx()
171+
172+
ref=ReferenceEvaluator(onx2)
173+
a=np.arange(10).astype(np.float32)
174+
got=ref.run(None, {"X":a})[0]
175+
self.assertEqualArray(np.exp(a),got)
176+
120177

121178
if__name__=="__main__":
122179
unittest.main(verbosity=2)

‎onnx_array_api/__init__.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
APIs to create ONNX Graphs.
33
"""
44

5-
__version__="0.3.0"
5+
__version__="0.3.1"
66
__author__="Xavier Dupré"

‎onnx_array_api/translate_api/builder_emitter.py‎

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,17 @@
44
from .base_emitterimportBaseEmitter
55

66
_types= {
7+
TensorProto.DOUBLE:"DOUBLE",
78
TensorProto.FLOAT:"FLOAT",
89
TensorProto.FLOAT16:"FLOAT16",
910
TensorProto.INT64:"INT64",
1011
TensorProto.INT32:"INT32",
12+
TensorProto.INT16:"INT16",
13+
TensorProto.UINT64:"UINT64",
14+
TensorProto.UINT32:"UINT32",
15+
TensorProto.UINT16:"UINT16",
16+
TensorProto.STRING:"STRING",
17+
TensorProto.BOOL:"BOOL",
1118
}
1219

1320

@@ -20,6 +27,10 @@ class BuilderEmitter(BaseEmitter):
2027
Converts event into proper code.
2128
"""
2229

30+
def__init__(self,make_model_function:str=""):
31+
super().__init__()
32+
self.make_model_function=make_model_function
33+
2334
defjoin(self,rows:List[str],single_line:bool=False)->str:
2435
"Join the rows"
2536
assert (
@@ -29,6 +40,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str:
2940

3041
def_emit_start(self,**kwargs:Dict[str,Any])->List[str]:
3142
self.opsets=kwargs.get("opsets", {})
43+
self.ir_version=kwargs.get("ir_version",None)
3244
return []
3345

3446
def_emit_to_onnx_model(self,**kwargs:Dict[str,Any])->List[str]:
@@ -43,12 +55,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
4355
)
4456
rows= [
4557
"",
46-
f"g = GraphBuilder({self.opsets})",
58+
(
59+
f"g = GraphBuilder({self.opsets}, ir_version={self.ir_version})"
60+
ifself.ir_version
61+
elsef"GraphBuilder({self.opsets})"
62+
),
4763
*inputs,
4864
f"{self.name}({inps})",
4965
*outputs,
5066
"model = g.to_onnx()",
5167
]
68+
ifself.make_model_function:
69+
rows= [
70+
"",
71+
"",
72+
f'def{self.make_model_function}() -> "ModelProto":',
73+
*[" "+_for_inrows[1:]],
74+
" return model",
75+
"",
76+
"",
77+
f"model ={self.make_model_function}()",
78+
]
5279
returnrows
5380

5481
def_emit_begin_graph(self,**kwargs:Dict[str,Any])->List[str]:
@@ -78,13 +105,16 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
78105
name=kwargs["name"]
79106
itype=kwargs.get("elem_type",0)
80107
shape=kwargs.get("shape",None)
108+
name=self._clean_result_name(name)
81109
ifitype==0:
82-
inp="X"
110+
inp=nameor"X"
83111
else:
84112
ifshapeisNone:
85-
inp=f'X: "{_itype_to_string(itype)}"'
113+
inp=f'{name}: "{_itype_to_string(itype)}"'
86114
else:
87-
inp=f'X: "{_itype_to_string(itype)}[{", ".join(map(str,shape))}]"'
115+
inp= (
116+
f'{name}: "{_itype_to_string(itype)}[{", ".join(map(str,shape))}]"'
117+
)
88118
self.inputs_full.append(inp)
89119
self.inputs.append(name)
90120
self.inputs_full_.append((name,_itype_to_string(itype),shape))
@@ -113,6 +143,7 @@ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
113143

114144
def_emit_output(self,**kwargs:Dict[str,Any])->List[str]:
115145
name=kwargs["name"]
146+
name=self._clean_result_name(name)
116147
itype=kwargs.get("elem_type",0)
117148
shape=kwargs.get("shape",None)
118149
self.outputs.append(name)
@@ -126,6 +157,8 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
126157
ifkwargs.get("domain","")!="":
127158
domain=kwargs["domain"]
128159
op_type=f"{domain}.{op_type}"
160+
else:
161+
domain=""
129162
atts=kwargs.get("atts", {})
130163
args= []
131164
fork,vinatts.items():
@@ -134,11 +167,22 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
134167
raiseNotImplementedError("Graph attribute not supported yet.")
135168
args.append(f"{k}={vatt}")
136169

137-
outs=", ".join(outputs)
138-
inps=", ".join(inputs)
170+
outs=", ".join(map(self._clean_result_name,outputs))
171+
inps=", ".join(map(self._clean_result_name,inputs))
172+
op_type=self._emit_node_type(op_type,domain)
173+
sdomain=""ifnotdomainelsef", domain={domain!r}"
139174
ifargs:
140175
sargs=", ".join(args)
141-
row=f"{outs} = op.{op_type}({inps},{sargs})"
176+
ifinps:
177+
row=f"{outs} = op.{op_type}({inps},{sargs}{sdomain})"
178+
else:
179+
row=f"{outs} = op.{op_type}({sargs}{sdomain})"
142180
else:
143-
row=f"{outs} = op.{op_type}({inps})"
181+
row=f"{outs} = op.{op_type}({inps}{sdomain})"
144182
return [row]
183+
184+
def_clean_result_name(self,name):
185+
returnname
186+
187+
def_emit_node_type(self,op_type,domain):
188+
returnop_type

‎onnx_array_api/translate_api/translate.py‎

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
3535
last_event=None
3636
ifisinstance(self.proto_,ModelProto):
3737
opsets= {d.domain:d.versionfordinself.proto_.opset_import}
38-
rows.extend(self.emitter(EventType.START,opsets=opsets))
38+
rows.extend(
39+
self.emitter(
40+
EventType.START,opsets=opsets,ir_version=self.proto_.ir_version
41+
)
42+
)
3943
inputs=self.proto_.graph.input
4044
outputs=self.proto_.graph.output
4145
nodes=self.proto_.graph.node

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp