Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit8835156

Browse files
authored
Add class to yield results form onnx model and computes differences between two runs (#71)
* update requirements* Add class to yield results* black* add sumarry* add distance* text* compare function* fix FusedMatMul* fix alpha* example* documentation* fix length* doc
1 parent6ed1d1c commit8835156

File tree

15 files changed

+1243
-4
lines changed

15 files changed

+1243
-4
lines changed

‎CHANGELOGS.rst‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
*:pr:`71`: adds tools to compare two onnx graphs
78
*:pr:`61`: adds function to plot onnx model as graphs
89
*:pr:`60`: supports translation of local functions
910
*:pr:`59`: add methods to update nodes in GraphAPI

‎_doc/api/reference.rst‎

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,33 @@ ExtendedReferenceEvaluator
55
++++++++++++++++++++++++++
66

77
..autoclass::onnx_array_api.reference.ExtendedReferenceEvaluator
8+
:members:
9+
10+
ResultType
11+
++++++++++
12+
13+
..autoclass::onnx_array_api.reference.ResultType
14+
:members:
15+
16+
ResultExecution
17+
+++++++++++++++
18+
19+
..autoclass::onnx_array_api.reference.ResultExecution
20+
:members:
21+
22+
YieldEvaluator
23+
++++++++++++++
24+
25+
..autoclass::onnx_array_api.reference.YieldEvaluator
26+
:members:
27+
28+
DistanceExecution
29+
+++++++++++++++++
30+
31+
..autoclass::onnx_array_api.reference.DistanceExecution
32+
:members:
33+
34+
compare_onnx_execution
35+
++++++++++++++++++++++
36+
37+
..autofunction::onnx_array_api.reference.compare_onnx_execution

‎_doc/command_lines.rst‎

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
=============
2+
command lines
3+
=============
4+
5+
compare
6+
=======
7+
8+
The function convers an onnx file into some code.
9+
10+
::
11+
12+
python -m compare -m1 model1.onnx -m2 model2.onnx -v 1
13+
14+
Output example::
15+
16+
[compare_onnx_execution] got 2 inputs
17+
[compare_onnx_execution] execute first model
18+
[compare_onnx_execution] got 5 results
19+
[compare_onnx_execution] execute second model
20+
[compare_onnx_execution] got 5 results
21+
[compare_onnx_execution] compute edit distance
22+
[compare_onnx_execution] got 4 pairs
23+
[compare_onnx_execution] done
24+
= | INPUT float32 5x6 AAAA X | INPUT float32 5x6 AAAA X
25+
= | INPUT float32 5x6 AAAA Y | INPUT float32 5x6 AAAA Y
26+
= | RESULT float32 5x6 AABB Add res | RESULT float32 5x6 AABB Add res
27+
= | RESULT float32 5x6 AAAA Cos Z | RESULT float32 5x6 AAAA Cos Z
28+
29+
..runpython::
30+
31+
from onnx_array_api._command_lines_parser import get_parser_compare
32+
get_parser_compare().print_help()
33+
34+
See function:func:`onnx_array_api.reference.compare_onnx_execution`.
35+
36+
translate
37+
=========
38+
39+
The function convers an onnx file into some code.
40+
41+
::
42+
43+
python -m translate ...
44+
45+
Output example::
46+
47+
not yet ready
48+
49+
..runpython::
50+
51+
from onnx_array_api._command_lines_parser import get_parser_translate
52+
get_parser_translate().print_help()

‎_doc/examples/plot_onnx_diff.py‎

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""
2+
3+
.. _l-onnx-diff-example:
4+
5+
Compares the conversions of the same model with different options
6+
=================================================================
7+
8+
The script compares two onnx models obtained with the same trained
9+
scikit-learn models but converted with different options.
10+
11+
A model
12+
+++++++
13+
"""
14+
15+
fromsklearn.mixtureimportGaussianMixture
16+
fromsklearn.datasetsimportload_iris
17+
fromsklearn.model_selectionimporttrain_test_split
18+
fromskl2onnximportto_onnx
19+
fromonnx_array_api.referenceimportcompare_onnx_execution
20+
fromonnx_array_api.plotting.text_plotimportonnx_simple_text_plot
21+
22+
23+
data=load_iris()
24+
X_train,X_test=train_test_split(data.data)
25+
model=GaussianMixture()
26+
model.fit(X_train)
27+
28+
#################################
29+
# Conversion to onnx
30+
# ++++++++++++++++++
31+
32+
onx=to_onnx(
33+
model,X_train[:1],options={id(model): {"score_samples":True}},target_opset=12
34+
)
35+
36+
print(onnx_simple_text_plot(onx))
37+
38+
##################################
39+
# Conversion to onnx without ReduceLogSumExp
40+
# ++++++++++++++++++++++++++++++++++++++++++
41+
42+
onx2=to_onnx(
43+
model,
44+
X_train[:1],
45+
options={id(model): {"score_samples":True}},
46+
black_op={"ReduceLogSumExp"},
47+
target_opset=12,
48+
)
49+
50+
print(onnx_simple_text_plot(onx2))
51+
52+
53+
#############################################
54+
# Differences
55+
# +++++++++++
56+
#
57+
# Function :func:`onnx_array_api.reference.compare_onnx_execution`
58+
# compares the intermediate results of two onnx models. Then it finds
59+
# the best alignmet between the two models using an edit distance.
60+
61+
res1,res2,align,dc=compare_onnx_execution(onx,onx2,verbose=1)
62+
print("------------")
63+
text=dc.to_str(res1,res2,align)
64+
print(text)
65+
66+
###############################
67+
# The display shows that ReduceSumSquare was replaced by Mul + ReduceSum,
68+
# and ReduceLogSumExp by ReduceMax + Sub + Exp + Log + Add.

‎_doc/index.rst‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ The objective is to speed up the implementation of converter libraries.
3636
tutorial/index
3737
api/index
3838
tech/index
39+
command_lines
3940
auto_examples/index
4041

4142
..toctree::

‎_doc/tutorial/index.rst‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ Tutorial
1010
graph_api
1111
light_api
1212
numpy_api
13+
tools
1314
benchmarks

‎_doc/tutorial/tools.rst‎

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
=====
2+
Tools
3+
=====
4+
5+
Some of useful tools.
6+
7+
Text representation
8+
===================
9+
10+
Plotting a graph is great but difficult to read when
11+
the graph is big and it is slow.
12+
:func:`onnx_array_api.plotting.text_plot.onnx_simple_text_plot`
13+
prints out a text representation.
14+
15+
Differences between two models
16+
==============================
17+
18+
How to understand the differences between two models
19+
assuming they are producing the same outputs?
20+
Example:ref:`l-onnx-diff-example` shows how to do it.

‎_unittests/ut_reference/test_array_tensor.py‎

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
importunittest
22
importnumpyasnp
33
fromonnximportTensorProto
4-
fromonnx.helperimportmake_graph,make_model,make_node,make_tensor_value_info
4+
fromonnx.helperimport (
5+
make_graph,
6+
make_model,
7+
make_node,
8+
make_tensor_value_info,
9+
make_opsetid,
10+
)
511
fromonnx_array_api.ext_test_caseimportExtTestCase
612
fromonnx_array_api.referenceimport (
713
to_array_extended,
@@ -51,6 +57,24 @@ def make_model_f8(fr, to):
5157
back=from_array_extended(got,"a")
5258
self.assertEqual(to,back.data_type)
5359

60+
deftest_fused_matmul(self):
61+
model=make_model(
62+
make_graph(
63+
[make_node("FusedMatMul", ["X","Y"], ["Z"],domain="com.microsoft")],
64+
"name",
65+
[
66+
make_tensor_value_info("X",TensorProto.FLOAT,None),
67+
make_tensor_value_info("Y",TensorProto.FLOAT,None),
68+
],
69+
[make_tensor_value_info("Z",TensorProto.FLOAT,None)],
70+
),
71+
opset_imports=[make_opsetid("",18),make_opsetid("com.microsoft",1)],
72+
)
73+
ref=ExtendedReferenceEvaluator(model)
74+
a=np.arange(4).reshape(-1,2)
75+
got=ref.run(None, {"X":a,"Y":a})
76+
self.assertEqualArray(a @a,got[0])
77+
5478

5579
if__name__=="__main__":
5680
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp