Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit4b5934c

Browse files
committed
fix translation of local functions
1 parentd6acd35 commit4b5934c

File tree

7 files changed

+464
-236
lines changed

7 files changed

+464
-236
lines changed

‎_doc/api/light_api.rst‎

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ translate
1616

1717
..autofunction::onnx_array_api.light_api.translate
1818

19+
make_helper
20+
+++++++++++
21+
22+
..autofunction::onnx_array_api.light_api.make_helper.make_node_extended
23+
24+
..autofunction::onnx_array_api.light_api.make_helper.make_ref_attribute
25+
1926
Classes for the Light API
2027
=========================
2128

@@ -68,7 +75,7 @@ Classes for the Translater
6875
BaseEmitter
6976
+++++++++++
7077

71-
..autoclass::onnx_array_api.light_api.emitter.BaseEmitter
78+
..autoclass::onnx_array_api.light_api.base_emitter.BaseEmitter
7279
:members:
7380

7481
Emitter
@@ -80,7 +87,7 @@ Emitter
8087
EventType
8188
+++++++++
8289

83-
..autoclass::onnx_array_api.light_api.translate.EventType
90+
..autoclass::onnx_array_api.light_api.base_emitter.EventType
8491
:members:
8592

8693
InnerEmitter

‎_unittests/ut_light_api/test_translate_classic.py‎

Lines changed: 110 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
fromonnximportModelProto,TensorProto,load
66
fromonnx.defsimportonnx_opset_version
77
fromonnx.referenceimportReferenceEvaluator
8+
fromonnx.reference.op_runimportOpRun
89
fromonnx.helperimport (
910
make_tensor_value_info,
1011
make_node,
@@ -68,7 +69,7 @@ def test_exp(self):
6869
functions = []
6970
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
7071
nodes.append(
71-
make_node(
72+
make_node_extended(
7273
'Exp',
7374
['X'],
7475
['Y']
@@ -144,14 +145,14 @@ def test_transpose(self):
144145
)
145146
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
146147
nodes.append(
147-
make_node(
148+
make_node_extended(
148149
'Reshape',
149150
['X', 'r'],
150151
['r0_0']
151152
)
152153
)
153154
nodes.append(
154-
make_node(
155+
make_node_extended(
155156
'Transpose',
156157
['r0_0'],
157158
['Y'],
@@ -210,7 +211,7 @@ def test_topk_reverse(self):
210211
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
211212
inputs.append(make_tensor_value_info('K', TensorProto.INT64, shape=[]))
212213
nodes.append(
213-
make_node(
214+
make_node_extended(
214215
'TopK',
215216
['X', 'K'],
216217
['Values', 'Indices'],
@@ -284,14 +285,14 @@ def test_aionnxml(self):
284285
)
285286
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
286287
nodes.append(
287-
make_node(
288+
make_node_extended(
288289
'Reshape',
289290
['X', 'r'],
290291
['USE']
291292
)
292293
)
293294
nodes.append(
294-
make_node(
295+
make_node_extended(
295296
'Normalizer',
296297
['USE'],
297298
['Y'],
@@ -317,16 +318,115 @@ def test_aionnxml(self):
317318
self.maxDiff=None
318319
self.assertEqual(expected,code)
319320

321+
@classmethod
322+
def_code_line(cls,code):
323+
lines=code.split("\n")
324+
return"\n".join(f"{i+1:03d}{line}"fori,lineinenumerate(lines))
325+
326+
@classmethod
327+
def_run(cls,code):
328+
try:
329+
code_compiled=compile(code,"<string>",mode="exec")
330+
exceptExceptionase:
331+
raiseAssertionError(
332+
f"Compilation failed due to{e}\n---\n{cls._code_line(code)}\n---\n{e}"
333+
)frome
334+
335+
importonnx
336+
importonnx.helper
337+
importonnx.numpy_helper
338+
importonnx_array_api.light_api.make_helper
339+
importonnx.reference.custom_element_types
340+
341+
deffrom_array_extended(tensor,name=None):
342+
dt=tensor.dtype
343+
if (
344+
dt==onnx.reference.custom_element_types.float8e4m3fn
345+
anddt.descr[0][0]=="e4m3fn"
346+
):
347+
to=TensorProto.FLOAT8E4M3FN
348+
dt_to=np.uint8
349+
elif (
350+
dt==onnx.reference.custom_element_types.bfloat16
351+
anddt.descr[0][0]=="bfloat16"
352+
):
353+
to=TensorProto.BFLOAT16
354+
dt_to=np.uint16
355+
else:
356+
returnonnx.numpy_helper.from_array(tensor,name)
357+
358+
t=onnx.numpy_helper.from_array(tensor.astype(dt_to),name)
359+
t.data_type=to
360+
returnt
361+
362+
globs=onnx.__dict__.copy()
363+
globs.update(onnx.helper.__dict__)
364+
globs.update(onnx.numpy_helper.__dict__)
365+
globs.update(onnx_array_api.light_api.make_helper.__dict__)
366+
globs.update(onnx.reference.custom_element_types.__dict__)
367+
globs["from_array_extended"]=from_array_extended
368+
locs= {}
369+
try:
370+
exec(code_compiled,globs,locs)
371+
exceptExceptionase:
372+
raiseAssertionError(
373+
f"Execution failed due to{e}\n---\n{cls._code_line(code)}\n---\n{e}"
374+
)frome
375+
returnglobs,locs
376+
320377
deftest_remove_nodes(self):
321378
path=os.path.join(
322379
os.path.dirname(__file__),"_data","custom_ops_type_inference_fails_0.onnx"
323380
)
324381
onx=load(path)
325-
text=translate(onx,api="onnx")
326-
withopen("debug_test_remove_nodes.py","w")asf:
327-
f.write(text)
382+
code=translate(onx,api="onnx")
383+
_,locs=self._run(code)
384+
self.assertIn("model",locs)
385+
model=locs["model"]
386+
x=np.arange(4).reshape((-1,2)).astype(np.float32)
387+
feeds= {"X":x}
388+
389+
classCustomGemmFloat8E4M3FN(OpRun):
390+
op_domain="onnx_extented.ortops.tutorial.cpu"
391+
392+
def_run(
393+
self,
394+
x,
395+
y,
396+
bias=None,
397+
scale_x=None,
398+
scale_y=None,
399+
scale_z=None,
400+
transA=False,
401+
transB=False,
402+
dtype=None,
403+
rowMajor=None,
404+
computeType=None,
405+
):
406+
ifscale_xisnotNone:
407+
x=x*scale_x
408+
iftransA:
409+
x=x.T
410+
ifscale_yisnotNone:
411+
y=y*scale_y
412+
iftransB:
413+
y=y.T
414+
z=x @y
415+
ifbiasisnotNone:
416+
z+=bias
417+
ifscale_zisnotNone:
418+
z=z/scale_z
419+
return (z,)
420+
421+
ref=ReferenceEvaluator(onx,new_ops=[CustomGemmFloat8E4M3FN])
422+
expected=ref.run(None,feeds)[0]
423+
ref2=ReferenceEvaluator(model,new_ops=[CustomGemmFloat8E4M3FN])
424+
got=ref2.run(None,feeds)[0]
425+
self.assertEqualArray(expected,got)
426+
427+
# with open("debug_test_remove_nodes.py", "w") as f:
428+
# f.write(code)
328429

329430

330431
if__name__=="__main__":
331-
# TestLightApi().test_topk()
332432
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp