Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit7895c27

Browse files
authored
Support translation of local functions (#60)
* add function to translate functions* doc* fix translation of local functions* refactoring* fix missing import* verbose* link
1 parent71aa3a0 commit7895c27

File tree

12 files changed

+492
-147
lines changed

12 files changed

+492
-147
lines changed

‎CHANGELOGS.rst‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
*:pr:`60`: supports translation of local functions
78
*:pr:`59`: add methods to update nodes in GraphAPI
89

910
0.1.3

‎_doc/api/light_api.rst‎

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ translate
1616

1717
..autofunction::onnx_array_api.light_api.translate
1818

19+
make_helper
20+
+++++++++++
21+
22+
..autofunction::onnx_array_api.light_api.make_helper.make_node_extended
23+
24+
..autofunction::onnx_array_api.light_api.make_helper.make_ref_attribute
25+
1926
Classes for the Light API
2027
=========================
2128

@@ -68,19 +75,13 @@ Classes for the Translater
6875
BaseEmitter
6976
+++++++++++
7077

71-
..autoclass::onnx_array_api.light_api.emitter.BaseEmitter
72-
:members:
73-
74-
Emitter
75-
+++++++
76-
77-
..autoclass::onnx_array_api.light_api.emitter.Emitter
78+
..autoclass::onnx_array_api.light_api.base_emitter.BaseEmitter
7879
:members:
7980

8081
EventType
8182
+++++++++
8283

83-
..autoclass::onnx_array_api.light_api.translate.EventType
84+
..autoclass::onnx_array_api.light_api.base_emitter.EventType
8485
:members:
8586

8687
InnerEmitter
@@ -89,6 +90,12 @@ InnerEmitter
8990
..autoclass::onnx_array_api.light_api.inner_emitter.InnerEmitter
9091
:members:
9192

93+
LightEmitter
94+
++++++++++++
95+
96+
..autoclass::onnx_array_api.light_api.light_emitter.LightEmitter
97+
:members:
98+
9299
Translater
93100
++++++++++
94101

Binary file not shown.

‎_unittests/ut_light_api/test_backend_export.py‎

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
importsys
12
importunittest
23
fromtypingimportAny,Dict,List,Optional
34
fromdifflibimportunified_diff
@@ -17,12 +18,16 @@
1718
make_opsetid,
1819
make_tensor_value_info,
1920
)
21+
fromonnx.reference.op_runimportto_array_extended
2022
fromonnx.numpy_helperimportfrom_array,to_array
2123
fromonnx.backend.baseimportDevice,DeviceType
2224
fromonnx_array_api.referenceimportExtendedReferenceEvaluator
25+
fromonnx_array_api.light_api.make_helperimportmake_node_extended
2326
fromonnx_array_api.light_apiimporttranslate
2427
fromonnx_array_api.plotting.text_plotimportonnx_simple_text_plot
2528

29+
verbosity=10if"-v"insys.argvor"--verbose"insys.argvelse0
30+
2631

2732
classReferenceImplementationError(RuntimeError):
2833
"Fails, export cannot be compared."
@@ -34,7 +39,7 @@ class ExportWrapper:
3439

3540
def__init__(self,model):
3641
self.model=model
37-
self.expected_sess=ExtendedReferenceEvaluator(self.model)
42+
self.expected_sess=ExtendedReferenceEvaluator(self.model,verbose=verbosity)
3843

3944
@property
4045
definput_names(self):
@@ -85,13 +90,15 @@ def run(
8590
locs= {
8691
"np":numpy,
8792
"to_array":to_array,
93+
"to_array_extended":to_array_extended,
8894
"from_array":from_array,
8995
"TensorProto":TensorProto,
9096
"make_function":make_function,
9197
"make_opsetid":make_opsetid,
9298
"make_model":make_model,
9399
"make_graph":make_graph,
94100
"make_node":make_node,
101+
"make_node_extended":make_node_extended,
95102
"make_tensor_value_info":make_tensor_value_info,
96103
}
97104
globs=locs.copy()
@@ -105,7 +112,7 @@ def run(
105112
f"Unable to executed code for api{api!r}\n{new_code}"
106113
)frome
107114
export_model=locs["model"]
108-
ref=ExtendedReferenceEvaluator(export_model)
115+
ref=ExtendedReferenceEvaluator(export_model,verbose=verbosity)
109116
try:
110117
got=ref.run(names,feeds)
111118
except (TypeError,AttributeError)ase:

‎_unittests/ut_light_api/test_translate.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
fromonnx.referenceimportReferenceEvaluator
77
fromonnx_array_api.ext_test_caseimportExtTestCase
88
fromonnx_array_api.light_apiimportstart,translate,g
9-
fromonnx_array_api.light_api.emitterimportEventType
9+
fromonnx_array_api.light_api.base_emitterimportEventType
1010

1111
OPSET_API=min(19,onnx_opset_version()-1)
1212

‎_unittests/ut_light_api/test_translate_classic.py‎

Lines changed: 116 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
fromonnximportModelProto,TensorProto,load
66
fromonnx.defsimportonnx_opset_version
77
fromonnx.referenceimportReferenceEvaluator
8+
fromonnx.reference.op_runimportOpRun
89
fromonnx.helperimport (
910
make_tensor_value_info,
1011
make_node,
@@ -68,7 +69,7 @@ def test_exp(self):
6869
functions = []
6970
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
7071
nodes.append(
71-
make_node(
72+
make_node_extended(
7273
'Exp',
7374
['X'],
7475
['Y']
@@ -144,14 +145,14 @@ def test_transpose(self):
144145
)
145146
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
146147
nodes.append(
147-
make_node(
148+
make_node_extended(
148149
'Reshape',
149150
['X', 'r'],
150151
['r0_0']
151152
)
152153
)
153154
nodes.append(
154-
make_node(
155+
make_node_extended(
155156
'Transpose',
156157
['r0_0'],
157158
['Y'],
@@ -210,7 +211,7 @@ def test_topk_reverse(self):
210211
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
211212
inputs.append(make_tensor_value_info('K', TensorProto.INT64, shape=[]))
212213
nodes.append(
213-
make_node(
214+
make_node_extended(
214215
'TopK',
215216
['X', 'K'],
216217
['Values', 'Indices'],
@@ -264,7 +265,6 @@ def test_aionnxml(self):
264265
.to_onnx()
265266
)
266267
code=translate(onx,api="onnx")
267-
print(code)
268268
expected=dedent(
269269
"""
270270
opset_imports = [
@@ -285,14 +285,14 @@ def test_aionnxml(self):
285285
)
286286
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
287287
nodes.append(
288-
make_node(
288+
make_node_extended(
289289
'Reshape',
290290
['X', 'r'],
291291
['USE']
292292
)
293293
)
294294
nodes.append(
295-
make_node(
295+
make_node_extended(
296296
'Normalizer',
297297
['USE'],
298298
['Y'],
@@ -318,7 +318,115 @@ def test_aionnxml(self):
318318
self.maxDiff=None
319319
self.assertEqual(expected,code)
320320

321+
@classmethod
322+
def_code_line(cls,code):
323+
lines=code.split("\n")
324+
return"\n".join(f"{i+1:03d}{line}"fori,lineinenumerate(lines))
325+
326+
@classmethod
327+
def_run(cls,code):
328+
try:
329+
code_compiled=compile(code,"<string>",mode="exec")
330+
exceptExceptionase:
331+
raiseAssertionError(
332+
f"Compilation failed due to{e}\n---\n{cls._code_line(code)}\n---\n{e}"
333+
)frome
334+
335+
importonnx
336+
importonnx.helper
337+
importonnx.numpy_helper
338+
importonnx_array_api.light_api.make_helper
339+
importonnx.reference.custom_element_types
340+
341+
deffrom_array_extended(tensor,name=None):
342+
dt=tensor.dtype
343+
if (
344+
dt==onnx.reference.custom_element_types.float8e4m3fn
345+
anddt.descr[0][0]=="e4m3fn"
346+
):
347+
to=TensorProto.FLOAT8E4M3FN
348+
dt_to=np.uint8
349+
elif (
350+
dt==onnx.reference.custom_element_types.bfloat16
351+
anddt.descr[0][0]=="bfloat16"
352+
):
353+
to=TensorProto.BFLOAT16
354+
dt_to=np.uint16
355+
else:
356+
returnonnx.numpy_helper.from_array(tensor,name)
357+
358+
t=onnx.numpy_helper.from_array(tensor.astype(dt_to),name)
359+
t.data_type=to
360+
returnt
361+
362+
globs=onnx.__dict__.copy()
363+
globs.update(onnx.helper.__dict__)
364+
globs.update(onnx.numpy_helper.__dict__)
365+
globs.update(onnx_array_api.light_api.make_helper.__dict__)
366+
globs.update(onnx.reference.custom_element_types.__dict__)
367+
globs["from_array_extended"]=from_array_extended
368+
locs= {}
369+
try:
370+
exec(code_compiled,globs,locs)
371+
exceptExceptionase:
372+
raiseAssertionError(
373+
f"Execution failed due to{e}\n---\n{cls._code_line(code)}\n---\n{e}"
374+
)frome
375+
returnglobs,locs
376+
377+
deftest_remove_nodes(self):
378+
path=os.path.join(
379+
os.path.dirname(__file__),"_data","custom_ops_type_inference_fails_0.onnx"
380+
)
381+
onx=load(path)
382+
code=translate(onx,api="onnx")
383+
_,locs=self._run(code)
384+
self.assertIn("model",locs)
385+
model=locs["model"]
386+
x=np.arange(4).reshape((-1,2)).astype(np.float32)
387+
feeds= {"X":x}
388+
389+
classCustomGemmFloat8E4M3FN(OpRun):
390+
op_domain="onnx_extented.ortops.tutorial.cpu"
391+
392+
def_run(
393+
self,
394+
x,
395+
y,
396+
bias=None,
397+
scale_x=None,
398+
scale_y=None,
399+
scale_z=None,
400+
transA=False,
401+
transB=False,
402+
dtype=None,
403+
rowMajor=None,
404+
computeType=None,
405+
):
406+
ifscale_xisnotNone:
407+
x=x*scale_x
408+
iftransA:
409+
x=x.T
410+
ifscale_yisnotNone:
411+
y=y*scale_y
412+
iftransB:
413+
y=y.T
414+
z=x @y
415+
ifbiasisnotNone:
416+
z+=bias
417+
ifscale_zisnotNone:
418+
z=z/scale_z
419+
return (z,)
420+
421+
ref=ReferenceEvaluator(onx,new_ops=[CustomGemmFloat8E4M3FN])
422+
expected=ref.run(None,feeds)[0]
423+
ref2=ReferenceEvaluator(model,new_ops=[CustomGemmFloat8E4M3FN])
424+
got=ref2.run(None,feeds)[0]
425+
self.assertEqualArray(expected,got)
426+
427+
# with open("debug_test_remove_nodes.py", "w") as f:
428+
# f.write(code)
429+
321430

322431
if__name__=="__main__":
323-
# TestLightApi().test_topk()
324432
unittest.main(verbosity=2)

‎onnx_array_api/light_api/__init__.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
6767
:param single_line: as a single line or not
6868
:param api: API to export into,
6969
default is `"light"` and this is handle by class
70-
:class:`onnx_array_api.light_api.emitter.Emitter`,
70+
:class:`onnx_array_api.light_api.light_emitter.LightEmitter`,
7171
another value is `"onnx"` which is the inner API implemented
7272
in onnx package.
7373
:return: code

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp