Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit4cf9dcc

Browse files
authored
Adds a mode to compare models without execution (#76)
* update requirements* Add a mode to compare model without execution* changelogs* improve initializer* fix display* fix side
1 parent7675869 commit4cf9dcc

File tree

4 files changed

+255
-38
lines changed

4 files changed

+255
-38
lines changed

‎CHANGELOGS.rst‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
*:pr:`76`: add a mode to compare models without execution
78
*:pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
89
*:pr:`71`: adds tools to compare two onnx graphs
910
*:pr:`61`: adds function to plot onnx model as graphs

‎_unittests/ut_reference/test_evaluator_yield.py‎

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
importunittest
22
importnumpyasnp
33
fromonnximportTensorProto
4+
fromonnx.checkerimportcheck_model
45
fromonnx.helperimport (
56
make_function,
67
make_graph,
@@ -9,6 +10,7 @@
910
make_opsetid,
1011
make_tensor_value_info,
1112
)
13+
fromonnx.numpy_helperimportfrom_array
1214
fromonnx.parserimportparse_model
1315
fromonnx_array_api.ext_test_caseimportExtTestCase
1416
fromonnx_array_api.referenceimport (
@@ -422,13 +424,13 @@ def test_distance_sequence_str(self):
422424
text=dc.to_str(s1,s2,align)
423425
self.assertIn("OUTPUT",text)
424426
expected="""
425-
001=|INPUTfloat322x2ABCDA|INPUTfloat322x2ABCDA
426-
002=|INPUTfloat322x2ABCDB|INPUTfloat322x2ABCDB
427-
003~|INPUTfloat322x3ABCDX|INPUTfloat322x2ABCDX
428-
004-|RESULTfloat322x2CEIOExpH|
429-
005=|RESULTfloat322x2CEIOLinearRegrY1|RESULTfloat322x2CEIOLinearRegrY1
430-
006~|RESULTfloat322x2CEIOAbsY|RESULTfloat322x3CEIPAbsZ
431-
007~|OUTPUTfloat322x2CEIOY|OUTPUTfloat322x2CEIPY
427+
001=|INPUTfloat322:2x2ABCDA|INPUTfloat322:2x2ABCDA
428+
002=|INPUTfloat322:2x2ABCDB|INPUTfloat322:2x2ABCDB
429+
003~|INPUTfloat322:2x3ABCDX|INPUTfloat322:2x2ABCDX
430+
004-|RESULTfloat322:2x2CEIOExpH|
431+
005=|RESULTfloat322:2x2CEIOLinearRegressioY1|RESULTfloat322:2x2CEIOLinearRegressioY1
432+
006~|RESULTfloat322:2x2CEIOAbsY|RESULTfloat322:2x3CEIPAbsZ
433+
007~|OUTPUTfloat322:2x2CEIOY|OUTPUTfloat322:2x2CEIPY
432434
""".replace(
433435
" ",""
434436
).strip(
@@ -460,6 +462,68 @@ def test_compare_execution(self):
460462
self.assertIn("CAAA Constant",text)
461463
self.assertEqual(len(align),5)
462464

465+
deftest_no_execution(self):
466+
model=make_model(
467+
make_graph(
468+
[
469+
make_node("Unsqueeze", ["X","zero"], ["xu1"]),
470+
make_node("Unsqueeze", ["xu1","un"], ["xu2"]),
471+
make_node("Reshape", ["xu2","shape1"], ["xm1"]),
472+
make_node("Reshape", ["Y","shape2"], ["xm2c"]),
473+
make_node("Cast", ["xm2c"], ["xm2"],to=1),
474+
make_node("MatMul", ["xm1","xm2"], ["xm"]),
475+
make_node("Reshape", ["xm","shape3"], ["Z"]),
476+
],
477+
"dummy",
478+
[
479+
make_tensor_value_info("X",TensorProto.FLOAT, [32,128]),
480+
make_tensor_value_info("Y",TensorProto.FLOAT, [3,5,128,64]),
481+
],
482+
[make_tensor_value_info("Z",TensorProto.FLOAT, [3,5,32,"N"])],
483+
[
484+
from_array(np.array([0],dtype=np.int64),name="zero"),
485+
from_array(np.array([1],dtype=np.int64),name="un"),
486+
from_array(np.array([1,32,128],dtype=np.int64),name="shape1"),
487+
from_array(np.array([15,128,64],dtype=np.int64),name="shape2"),
488+
from_array(np.array([3,5,32,64],dtype=np.int64),name="shape3"),
489+
],
490+
)
491+
)
492+
check_model(model)
493+
res1,res2,align,dc=compare_onnx_execution(model,model,mode="nodes")
494+
text=dc.to_str(res1,res2,align)
495+
self.assertIn("012 = | NODE",text)
496+
497+
model2=make_model(
498+
make_graph(
499+
[
500+
make_node("Unsqueeze", ["X","zero"], ["xu1"]),
501+
make_node("Unsqueeze", ["xu1","un"], ["xu2"]),
502+
make_node("Reshape", ["xu2","shape1"], ["xm1"]),
503+
make_node("Reshape", ["Y","shape2"], ["xm2c"]),
504+
make_node("MatMul", ["xm1","xm2c"], ["xm"]),
505+
make_node("Reshape", ["xm","shape3"], ["Z"]),
506+
],
507+
"dummy",
508+
[
509+
make_tensor_value_info("X",TensorProto.FLOAT, [32,128]),
510+
make_tensor_value_info("Y",TensorProto.FLOAT, [3,5,128,64]),
511+
],
512+
[make_tensor_value_info("Z",TensorProto.FLOAT, [3,5,32,"N"])],
513+
[
514+
from_array(np.array([0],dtype=np.int64),name="zero"),
515+
from_array(np.array([1],dtype=np.int64),name="un"),
516+
from_array(np.array([1,32,128],dtype=np.int64),name="shape1"),
517+
from_array(np.array([15,128,64],dtype=np.int64),name="shape2"),
518+
from_array(np.array([3,5,32,64],dtype=np.int64),name="shape3"),
519+
],
520+
)
521+
)
522+
check_model(model2)
523+
res1,res2,align,dc=compare_onnx_execution(model,model2,mode="nodes")
524+
text=dc.to_str(res1,res2,align)
525+
self.assertIn("012 = | NODE",text)
526+
463527

464528
if__name__=="__main__":
465529
unittest.main(verbosity=2)

‎onnx_array_api/_command_lines_parser.py‎

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def get_main_parser() -> ArgumentParser:
2020
Selects a command.
2121
2222
'translate' exports an onnx graph into a piece of code replicating it,
23-
'compares' compares the execution of two onnx models
23+
'compare' compares the execution of two onnx models
2424
"""
2525
),
2626
)
@@ -90,6 +90,13 @@ def get_parser_compare() -> ArgumentParser:
9090
required=True,
9191
help="second onnx model",
9292
)
93+
parser.add_argument(
94+
"-m",
95+
"--mode",
96+
choices=["execute","nodes"],
97+
default="execute",
98+
help="compare the execution ('execute') or the nodes only ('nodes')",
99+
)
93100
parser.add_argument(
94101
"-v",
95102
"--verbose",
@@ -112,8 +119,10 @@ def _cmd_compare(argv: List[Any]):
112119
args=parser.parse_args(argv[1:])
113120
onx1=onnx.load(args.model1)
114121
onx2=onnx.load(args.model2)
115-
res1,res2,align,dc=compare_onnx_execution(onx1,onx2,verbose=args.verbose)
116-
text=dc.to_str(res1,res2,align,column_size=args.column_size)
122+
res1,res2,align,dc=compare_onnx_execution(
123+
onx1,onx2,verbose=args.verbose,mode=args.mode
124+
)
125+
text=dc.to_str(res1,res2,align,column_size=int(args.column_size))
117126
print(text)
118127

119128

@@ -127,7 +136,7 @@ def main(argv: Optional[List[Any]] = None):
127136
parser=get_main_parser()
128137
parser.parse_args(argv)
129138
else:
130-
parsers=dict(translate=get_parser_translate)
139+
parsers=dict(translate=get_parser_translate,compare=get_parser_compare)
131140
cmd=argv[0]
132141
ifcmdnotinparsers:
133142
raiseValueError(

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp