Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit718a0a5

Browse files
authored
Adds tools to compare models (#11)
* Adds tools to compare models* update path
1 parentac28cb9 commit718a0a5

File tree

16 files changed

+447
-8
lines changed

16 files changed

+447
-8
lines changed

‎.gitignore‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@ _doc/examples/plot_*.png
1515
_doc/_static/require.js
1616
_doc/_static/viz.js
1717
_unittests/ut__main/*.png
18+
_doc/examples/data/*.optimized.onnx
19+
_doc/examples/*.html
20+
_unittests/ut__main/_cache/*
21+
_unittests/ut__main/*.html

‎_doc/api/index.rst‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ API
1111
npx_jit
1212
npx_annot
1313
npx_numpy
14+
onnx_tools
1415
ort
1516
plotting
1617
tools
17-

‎_doc/api/onnx_tools.rst‎

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
.. _l-api-onnx-tools:
2+
3+
onnx tools
4+
==========
5+
6+
Differences
7+
+++++++++++
8+
9+
..autofunction::onnx_array_api.validation.diff.html_diff
10+
11+
..autofunction::onnx_array_api.validation.diff.text_diff
12+
13+
Protos
14+
++++++
15+
16+
..autofunction::onnx_array_api.validation.tools.randomize_proto

‎_doc/api/tools.rst‎

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ Benchmark
88

99
..autofunction::onnx_array_api.ext_test_case.measure_time
1010

11+
Examples
12+
++++++++
13+
14+
..autofunction::onnx_array_api.ext_test_case.example_path
15+
1116
Profiling
1217
+++++++++
1318

@@ -25,5 +30,3 @@ Unit tests
2530

2631
..autoclass::onnx_array_api.ext_test_case.ExtTestCase
2732
:members:
28-
29-

‎_doc/examples/data/small.onnx‎

315 KB
Binary file not shown.

‎_doc/examples/plot_optimization.py‎

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""
2+
3+
.. _l-onnx-array-onnxruntime-optimization:
4+
5+
Optimization with onnxruntime
6+
=============================
7+
8+
9+
Optimize a model with onnxruntime
10+
+++++++++++++++++++++++++++++++++
11+
"""
12+
importos
13+
frompprintimportpprint
14+
importnumpy
15+
frompandasimportDataFrame
16+
importmatplotlib.pyplotasplt
17+
fromonnximportload
18+
fromonnx_array_api.ext_test_caseimportexample_path
19+
fromonnx_array_api.plotting.text_plotimportonnx_simple_text_plot
20+
fromonnx_array_api.validation.diffimporttext_diff,html_diff
21+
fromonnxruntimeimportGraphOptimizationLevel,InferenceSession,SessionOptions
22+
fromonnx_array_api.ext_test_caseimportmeasure_time
23+
fromonnx_array_api.ort.ort_optimizersimportort_optimized_model
24+
25+
26+
filename=example_path("data/small.onnx")
27+
optimized=filename+".optimized.onnx"
28+
29+
ifnotos.path.exists(optimized):
30+
ort_optimized_model(filename,output=optimized)
31+
print(optimized)
32+
33+
#############################
34+
# Output comparison
35+
# +++++++++++++++++
36+
37+
so=SessionOptions()
38+
so.graph_optimization_level=GraphOptimizationLevel.ORT_ENABLE_ALL
39+
img=numpy.random.random((1,3,112,112)).astype(numpy.float32)
40+
41+
sess=InferenceSession(filename,so)
42+
sess_opt=InferenceSession(optimized,so)
43+
input_name=sess.get_inputs()[0].name
44+
out=sess.run(None, {input_name:img})[0]
45+
out_opt=sess_opt.run(None, {input_name:img})[0]
46+
ifout.shape!=out_opt.shape:
47+
print("ERROR shape are different {out.shape} != {out_opt.shape}")
48+
diff=numpy.abs(out-out_opt).max()
49+
print(f"Differences:{diff}")
50+
51+
####################################
52+
# Difference
53+
# ++++++++++
54+
#
55+
# Unoptimized model.
56+
57+
withopen(filename,"rb")asf:
58+
model=load(f)
59+
print("first model to text...")
60+
text1=onnx_simple_text_plot(model,indent=False)
61+
print(text1)
62+
63+
#####################################
64+
# Optimized model.
65+
66+
67+
withopen(optimized,"rb")asf:
68+
model=load(f)
69+
print("second model to text...")
70+
text2=onnx_simple_text_plot(model,indent=False)
71+
print(text2)
72+
73+
########################################
74+
# Differences
75+
76+
print("differences...")
77+
print(text_diff(text1,text2))
78+
79+
#####################################
80+
# HTML version.
81+
82+
print("html differences...")
83+
output=html_diff(text1,text2)
84+
withopen("diff_html.html","w",encoding="utf-8")asf:
85+
f.write(output)
86+
print("done.")
87+
88+
#####################################
89+
# Benchmark
90+
# +++++++++
91+
92+
img=numpy.random.random((1,3,112,112)).astype(numpy.float32)
93+
94+
t1=measure_time(lambda:sess.run(None, {input_name:img}),repeat=25,number=25)
95+
t1["name"]="original"
96+
print("Original model")
97+
pprint(t1)
98+
99+
t2=measure_time(lambda:sess_opt.run(None, {input_name:img}),repeat=25,number=25)
100+
t2["name"]="optimized"
101+
print("Optimized")
102+
pprint(t2)
103+
104+
105+
############################
106+
# Plots
107+
# +++++
108+
109+
110+
fig,ax=plt.subplots(1,1,figsize=(12,4))
111+
112+
df=DataFrame([t1,t2]).set_index("name")
113+
print(df)
114+
115+
print(df["average"].values)
116+
print((df["average"]-df["deviation"]).values)
117+
118+
ax.bar(df.index,df["average"].values,yerr=df["deviation"].values,capsize=6)
119+
ax.set_title("Measure performance of optimized model\nlower is better")
120+
plt.grid()
121+
fig.savefig("plot_optimization.png")
315 KB
Binary file not shown.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
importunittest
2+
fromonnximportload
3+
fromonnx.checkerimportcheck_model
4+
fromonnx_array_api.ext_test_caseimportExtTestCase
5+
fromonnx_array_api.ort.ort_optimizersimportort_optimized_model
6+
fromonnx_array_api.validation.diffimporttext_diff,html_diff
7+
8+
9+
classTestDiff(ExtTestCase):
10+
deftest_diff_optimized(self):
11+
data=self.relative_path(__file__,"data","small.onnx")
12+
withopen(data,"rb")asf:
13+
model=load(f)
14+
optimized=ort_optimized_model(model)
15+
check_model(optimized)
16+
diff=text_diff(model,optimized)
17+
self.assertIn("^^^^^^^^^^^^^^^^",diff)
18+
ht=html_diff(model,optimized)
19+
self.assertIn("<html><body>",ht)
20+
21+
22+
if__name__=="__main__":
23+
unittest.main(verbosity=2)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
importunittest
2+
fromonnximportload
3+
fromonnx.checkerimportcheck_model
4+
fromonnx_array_api.ext_test_caseimportExtTestCase
5+
fromonnx_array_api.validation.toolsimportrandomize_proto
6+
7+
8+
classTestTools(ExtTestCase):
9+
deftest_randomize_proto(self):
10+
data=self.relative_path(__file__,"data","small.onnx")
11+
withopen(data,"rb")asf:
12+
model=load(f)
13+
check_model(model)
14+
rnd=randomize_proto(model)
15+
self.assertEqual(len(model.SerializeToString()),len(rnd.SerializeToString()))
16+
check_model(rnd)
17+
18+
19+
if__name__=="__main__":
20+
unittest.main(verbosity=2)

‎onnx_array_api/ext_test_case.py‎

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
importos
12
importsys
23
importunittest
34
importwarnings
@@ -30,6 +31,20 @@ def call_f(self):
3031
returnwrapper
3132

3233

34+
defexample_path(path:str)->str:
35+
"""
36+
Fixes a path for the examples.
37+
Helps running the example within a unit test.
38+
"""
39+
ifos.path.exists(path):
40+
returnpath
41+
this=os.path.abspath(os.path.dirname(__file__))
42+
full=os.path.join(this,"..","_doc","examples",path)
43+
ifos.path.exists(full):
44+
returnfull
45+
raiseFileNotFoundError(f"Unable to find path{path!r} or{full!r}.")
46+
47+
3348
defmeasure_time(
3449
stmt:Callable,
3550
context:Optional[Dict[str,Any]]=None,
@@ -207,3 +222,18 @@ def capture(self, fct: Callable):
207222
withredirect_stderr(serr):
208223
res=fct()
209224
returnres,sout.getvalue(),serr.getvalue()
225+
226+
defrelative_path(self,filename:str,*names:List[str])->str:
227+
"""
228+
Returns a path relative to the folder *filename*
229+
is in. The function checks the path existence.
230+
231+
:param filename: filename
232+
:param names: additional path pieces
233+
:return: new path
234+
"""
235+
dir=os.path.abspath(os.path.dirname(filename))
236+
name=os.path.join(dir,*names)
237+
ifnotos.path.exists(name):
238+
raiseFileNotFoundError(f"Path{name!r} does not exists.")
239+
returnname

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp