Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit718a0a5

Browse files
authored
Adds tools to compare models (#11)
* Adds tools to compare models* update path
1 parentac28cb9 commit718a0a5

File tree

16 files changed

+447
-8
lines changed

16 files changed

+447
-8
lines changed

‎.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@ _doc/examples/plot_*.png
1515
_doc/_static/require.js
1616
_doc/_static/viz.js
1717
_unittests/ut__main/*.png
18+
_doc/examples/data/*.optimized.onnx
19+
_doc/examples/*.html
20+
_unittests/ut__main/_cache/*
21+
_unittests/ut__main/*.html

‎_doc/api/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ API
1111
npx_jit
1212
npx_annot
1313
npx_numpy
14+
onnx_tools
1415
ort
1516
plotting
1617
tools
17-

‎_doc/api/onnx_tools.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
.. _l-api-onnx-tools:
2+
3+
onnx tools
4+
==========
5+
6+
Differences
7+
+++++++++++
8+
9+
..autofunction::onnx_array_api.validation.diff.html_diff
10+
11+
..autofunction::onnx_array_api.validation.diff.text_diff
12+
13+
Protos
14+
++++++
15+
16+
..autofunction::onnx_array_api.validation.tools.randomize_proto

‎_doc/api/tools.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ Benchmark
88

99
..autofunction::onnx_array_api.ext_test_case.measure_time
1010

11+
Examples
12+
++++++++
13+
14+
..autofunction::onnx_array_api.ext_test_case.example_path
15+
1116
Profiling
1217
+++++++++
1318

@@ -25,5 +30,3 @@ Unit tests
2530

2631
..autoclass::onnx_array_api.ext_test_case.ExtTestCase
2732
:members:
28-
29-

‎_doc/examples/data/small.onnx

315 KB
Binary file not shown.

‎_doc/examples/plot_optimization.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""
2+
3+
.. _l-onnx-array-onnxruntime-optimization:
4+
5+
Optimization with onnxruntime
6+
=============================
7+
8+
9+
Optimize a model with onnxruntime
10+
+++++++++++++++++++++++++++++++++
11+
"""
12+
importos
13+
frompprintimportpprint
14+
importnumpy
15+
frompandasimportDataFrame
16+
importmatplotlib.pyplotasplt
17+
fromonnximportload
18+
fromonnx_array_api.ext_test_caseimportexample_path
19+
fromonnx_array_api.plotting.text_plotimportonnx_simple_text_plot
20+
fromonnx_array_api.validation.diffimporttext_diff,html_diff
21+
fromonnxruntimeimportGraphOptimizationLevel,InferenceSession,SessionOptions
22+
fromonnx_array_api.ext_test_caseimportmeasure_time
23+
fromonnx_array_api.ort.ort_optimizersimportort_optimized_model
24+
25+
26+
filename=example_path("data/small.onnx")
27+
optimized=filename+".optimized.onnx"
28+
29+
ifnotos.path.exists(optimized):
30+
ort_optimized_model(filename,output=optimized)
31+
print(optimized)
32+
33+
#############################
34+
# Output comparison
35+
# +++++++++++++++++
36+
37+
so=SessionOptions()
38+
so.graph_optimization_level=GraphOptimizationLevel.ORT_ENABLE_ALL
39+
img=numpy.random.random((1,3,112,112)).astype(numpy.float32)
40+
41+
sess=InferenceSession(filename,so)
42+
sess_opt=InferenceSession(optimized,so)
43+
input_name=sess.get_inputs()[0].name
44+
out=sess.run(None, {input_name:img})[0]
45+
out_opt=sess_opt.run(None, {input_name:img})[0]
46+
ifout.shape!=out_opt.shape:
47+
print("ERROR shape are different {out.shape} != {out_opt.shape}")
48+
diff=numpy.abs(out-out_opt).max()
49+
print(f"Differences:{diff}")
50+
51+
####################################
52+
# Difference
53+
# ++++++++++
54+
#
55+
# Unoptimized model.
56+
57+
withopen(filename,"rb")asf:
58+
model=load(f)
59+
print("first model to text...")
60+
text1=onnx_simple_text_plot(model,indent=False)
61+
print(text1)
62+
63+
#####################################
64+
# Optimized model.
65+
66+
67+
withopen(optimized,"rb")asf:
68+
model=load(f)
69+
print("second model to text...")
70+
text2=onnx_simple_text_plot(model,indent=False)
71+
print(text2)
72+
73+
########################################
74+
# Differences
75+
76+
print("differences...")
77+
print(text_diff(text1,text2))
78+
79+
#####################################
80+
# HTML version.
81+
82+
print("html differences...")
83+
output=html_diff(text1,text2)
84+
withopen("diff_html.html","w",encoding="utf-8")asf:
85+
f.write(output)
86+
print("done.")
87+
88+
#####################################
89+
# Benchmark
90+
# +++++++++
91+
92+
img=numpy.random.random((1,3,112,112)).astype(numpy.float32)
93+
94+
t1=measure_time(lambda:sess.run(None, {input_name:img}),repeat=25,number=25)
95+
t1["name"]="original"
96+
print("Original model")
97+
pprint(t1)
98+
99+
t2=measure_time(lambda:sess_opt.run(None, {input_name:img}),repeat=25,number=25)
100+
t2["name"]="optimized"
101+
print("Optimized")
102+
pprint(t2)
103+
104+
105+
############################
106+
# Plots
107+
# +++++
108+
109+
110+
fig,ax=plt.subplots(1,1,figsize=(12,4))
111+
112+
df=DataFrame([t1,t2]).set_index("name")
113+
print(df)
114+
115+
print(df["average"].values)
116+
print((df["average"]-df["deviation"]).values)
117+
118+
ax.bar(df.index,df["average"].values,yerr=df["deviation"].values,capsize=6)
119+
ax.set_title("Measure performance of optimized model\nlower is better")
120+
plt.grid()
121+
fig.savefig("plot_optimization.png")
315 KB
Binary file not shown.

‎_unittests/ut_validation/test_diff.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
importunittest
2+
fromonnximportload
3+
fromonnx.checkerimportcheck_model
4+
fromonnx_array_api.ext_test_caseimportExtTestCase
5+
fromonnx_array_api.ort.ort_optimizersimportort_optimized_model
6+
fromonnx_array_api.validation.diffimporttext_diff,html_diff
7+
8+
9+
classTestDiff(ExtTestCase):
10+
deftest_diff_optimized(self):
11+
data=self.relative_path(__file__,"data","small.onnx")
12+
withopen(data,"rb")asf:
13+
model=load(f)
14+
optimized=ort_optimized_model(model)
15+
check_model(optimized)
16+
diff=text_diff(model,optimized)
17+
self.assertIn("^^^^^^^^^^^^^^^^",diff)
18+
ht=html_diff(model,optimized)
19+
self.assertIn("<html><body>",ht)
20+
21+
22+
if__name__=="__main__":
23+
unittest.main(verbosity=2)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
importunittest
2+
fromonnximportload
3+
fromonnx.checkerimportcheck_model
4+
fromonnx_array_api.ext_test_caseimportExtTestCase
5+
fromonnx_array_api.validation.toolsimportrandomize_proto
6+
7+
8+
classTestTools(ExtTestCase):
9+
deftest_randomize_proto(self):
10+
data=self.relative_path(__file__,"data","small.onnx")
11+
withopen(data,"rb")asf:
12+
model=load(f)
13+
check_model(model)
14+
rnd=randomize_proto(model)
15+
self.assertEqual(len(model.SerializeToString()),len(rnd.SerializeToString()))
16+
check_model(rnd)
17+
18+
19+
if__name__=="__main__":
20+
unittest.main(verbosity=2)

‎onnx_array_api/ext_test_case.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
importos
12
importsys
23
importunittest
34
importwarnings
@@ -30,6 +31,20 @@ def call_f(self):
3031
returnwrapper
3132

3233

34+
defexample_path(path:str)->str:
35+
"""
36+
Fixes a path for the examples.
37+
Helps running the example within a unit test.
38+
"""
39+
ifos.path.exists(path):
40+
returnpath
41+
this=os.path.abspath(os.path.dirname(__file__))
42+
full=os.path.join(this,"..","_doc","examples",path)
43+
ifos.path.exists(full):
44+
returnfull
45+
raiseFileNotFoundError(f"Unable to find path{path!r} or{full!r}.")
46+
47+
3348
defmeasure_time(
3449
stmt:Callable,
3550
context:Optional[Dict[str,Any]]=None,
@@ -207,3 +222,18 @@ def capture(self, fct: Callable):
207222
withredirect_stderr(serr):
208223
res=fct()
209224
returnres,sout.getvalue(),serr.getvalue()
225+
226+
defrelative_path(self,filename:str,*names:List[str])->str:
227+
"""
228+
Returns a path relative to the folder *filename*
229+
is in. The function checks the path existence.
230+
231+
:param filename: filename
232+
:param names: additional path pieces
233+
:return: new path
234+
"""
235+
dir=os.path.abspath(os.path.dirname(filename))
236+
name=os.path.join(dir,*names)
237+
ifnotos.path.exists(name):
238+
raiseFileNotFoundError(f"Path{name!r} does not exists.")
239+
returnname

‎onnx_array_api/ort/ort_optimizers.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
fromtypingimportUnion
1+
fromtypingimportUnion,Optional
22
fromonnximportModelProto,load
33
fromonnxruntimeimportInferenceSession,SessionOptions
44
fromonnxruntime.capi._pybind_stateimportGraphOptimizationLevel
55
from ..cacheimportget_cache_file
66

77

88
defort_optimized_model(
9-
onx:Union[str,ModelProto],level:str="ORT_ENABLE_ALL"
9+
onx:Union[str,ModelProto],
10+
level:str="ORT_ENABLE_ALL",
11+
output:Optional[str]=None,
1012
)->Union[str,ModelProto]:
1113
"""
1214
Returns the optimized model used by onnxruntime before
@@ -15,6 +17,7 @@ def ort_optimized_model(
1517
:param onx: ModelProto
1618
:param level: optimization level, `'ORT_ENABLE_BASIC'`,
1719
`'ORT_ENABLE_EXTENDED'`, `'ORT_ENABLE_ALL'`
20+
:param output: output file if the proposed cache is not wanted
1821
:return: optimized model
1922
"""
2023
glevel=getattr(GraphOptimizationLevel,level,None)
@@ -23,13 +26,18 @@ def ort_optimized_model(
2326
f"Unrecognized level{level!r} among{dir(GraphOptimizationLevel)}."
2427
)
2528

26-
cache=get_cache_file("ort_optimized_model.onnx",remove=True)
29+
ifoutputisnotNone:
30+
cache=output
31+
else:
32+
cache=get_cache_file("ort_optimized_model.onnx",remove=True)
2733
so=SessionOptions()
2834
so.graph_optimization_level=glevel
2935
so.optimized_model_filepath=str(cache)
3036
InferenceSession(onxifisinstance(onx,str)elseonx.SerializeToString(),so)
31-
ifnotcache.exists():
37+
ifoutputisNoneandnotcache.exists():
3238
raiseRuntimeError(f"The optimized model{str(cache)!r} not found.")
39+
ifoutputisnotNone:
40+
returnoutput
3341
ifisinstance(onx,str):
3442
returnstr(cache)
3543
opt_onx=load(str(cache))

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp