Commit dc61d97

Adds function to profile onnxruntime (#12)

* Adds function to profile onnxruntime
* examples
* fix example
* fix requirements.txt
* fix example
* documentation
* update examples
* remove a warning
* fix providers

1 parent 718a0a5 · commit dc61d97
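In short, the commit introduces `onnx_array_api.ort.ort_profile.ort_profile`, documents it, exercises it in a new example and unit test, and pins `providers=["CPUExecutionProvider"]` wherever an `InferenceSession` is created. A minimal sketch of the new API, based on how the examples and tests below call it; the model path is a placeholder:

    import numpy
    from onnx_array_api.ort.ort_profile import ort_profile

    # "model.onnx" is a placeholder path; feeds maps input names to numpy arrays.
    feeds = {"input": numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)}
    prof = ort_profile(
        "model.onnx",
        feeds,
        repeat=6,                   # run the inference several times
        disable_optimization=True,  # profile the graph as stored on disk
        providers=["CPUExecutionProvider"],
    )
    print(prof)  # a pandas DataFrame (a list of raw events with as_df=False)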

File tree: 11 files changed, +387 −10 lines changed

_doc/api/ort.rst

Lines changed: 10 additions & 0 deletions

@@ -3,6 +3,11 @@
 ort
 ===
 
+Optimization
+++++++++++++
+
+.. autofunction:: onnx_array_api.ort.ort_optimizers.ort_optimized_model
+
 OrtTensor
 +++++++++
 
@@ -15,3 +20,8 @@ OrtTensor
 .. autoclass:: onnx_array_api.ort.ort_tensors.OrtTensor
     :members:
 
+Profiling
++++++++++
+
+.. autofunction:: onnx_array_api.ort.ort_profile.ort_profile
+
_doc/examples/plot_benchmark_rf.py

Lines changed: 4 additions & 4 deletions

@@ -200,13 +200,13 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
         cache_dir, f"nf-{X.shape[1]}-rf-J-{n_j}-E-{n_estimators}-D-{max_depth}.onnx"
     )
     if os.path.exists(cache_name):
-        sess = InferenceSession(cache_name, so)
+        sess = InferenceSession(cache_name, so, providers=["CPUExecutionProvider"])
     else:
         bar.set_description(f"J={n_j} E={n_estimators} D={max_depth} cvt onnx")
         onx = to_onnx(rf, X[:1])
         with open(cache_name, "wb") as f:
             f.write(onx.SerializeToString())
-        sess = InferenceSession(cache_name, so)
+        sess = InferenceSession(cache_name, so, providers=["CPUExecutionProvider"])
     onx_size = os.stat(cache_name).st_size
 
     # run once to avoid counting the first run
@@ -234,7 +234,7 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
     o1.update(dict(avg=mean, med=med, n_runs=r, ttime=t, name="base"))
     data.append(o1)
 
-    # baseline
+    # onnxruntime
     bar.set_description(f"J={n_j} E={n_estimators} D={max_depth} predictO")
     r, t, mean, med = measure_inference(
         lambda x: sess.run(None, {"X": x}), X, repeat=repeat, max_time=max_time
@@ -258,7 +258,7 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
 
 #######################################################
 # Printing the data
-print(df)
+df
 
 #####################################################
 # Plot

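The `providers` argument added above is not optional in recent onnxruntime releases: when the installed build ships more than one execution provider, the `InferenceSession` constructor raises an error if no list is given. A small sketch of the usual fallback pattern, with a placeholder model path:

    from onnxruntime import InferenceSession, SessionOptions, get_available_providers

    so = SessionOptions()
    # Prefer CUDA when the installed build exposes it, otherwise run on CPU.
    preferred = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    providers = [p for p in preferred if p in get_available_providers()]
    sess = InferenceSession("model.onnx", so, providers=providers)  # placeholder path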
_doc/examples/plot_optimization.py

Lines changed: 11 additions & 5 deletions

@@ -5,6 +5,12 @@
 Optimization with onnxruntime
 =============================
 
+*onnxruntime* optimizes the onnx graph by default before running
+the inference. It modifies, fuses or adds new operators.
+Some of them are standard onnx operators, some of them
+are implemented in onnxruntime (see `Supported Operators
+<https://github.com/microsoft/onnxruntime/blob/main/docs/OperatorKernels.md>`_).
+This example looks into the differences of two models.
 
 Optimize a model with onnxruntime
 +++++++++++++++++++++++++++++++++
@@ -38,8 +44,8 @@
 so.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
 img = numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)
 
-sess = InferenceSession(filename, so)
-sess_opt = InferenceSession(optimized, so)
+sess = InferenceSession(filename, so, providers=["CPUExecutionProvider"])
+sess_opt = InferenceSession(optimized, so, providers=["CPUExecutionProvider"])
 input_name = sess.get_inputs()[0].name
 out = sess.run(None, {input_name: img})[0]
 out_opt = sess_opt.run(None, {input_name: img})[0]
@@ -110,10 +116,10 @@
 fig, ax = plt.subplots(1, 1, figsize=(12, 4))
 
 df = DataFrame([t1, t2]).set_index("name")
-print(df)
+df
 
-print(df["average"].values)
-print((df["average"] - df["deviation"]).values)
+#######################################
+# And the graph is:
 
 ax.bar(df.index, df["average"].values, yerr=df["deviation"].values, capsize=6)
 ax.set_title("Measure performance of optimized model\nlower is better")

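The on-disk optimization used here relies on a standard onnxruntime mechanism, visible in the onnx_array_api/ort/ort_optimizers.py hunk further down: when `SessionOptions.optimized_model_filepath` is set, creating the session writes the optimized graph to that path. A minimal sketch with placeholder paths:

    from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions

    so = SessionOptions()
    so.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
    so.optimized_model_filepath = "model.optimized.onnx"  # placeholder output path
    # Building the session triggers the optimization and saves the result.
    InferenceSession("model.onnx", so, providers=["CPUExecutionProvider"])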
_doc/examples/plot_profiling.py

Lines changed: 251 additions & 0 deletions

@@ -0,0 +1,251 @@
+"""
+
+.. _l-onnx-array-onnxruntime-profiling:
+
+Profiling with onnxruntime
+==========================
+
+*onnxruntime* optimizes the onnx graph by default before running
+the inference. It modifies, fuses or adds new operators.
+Some of them are standard onnx operators, some of them
+are implemented in onnxruntime (see `Supported Operators
+<https://github.com/microsoft/onnxruntime/blob/main/docs/OperatorKernels.md>`_).
+This example profiles the two models.
+
+Optimize a model with onnxruntime
++++++++++++++++++++++++++++++++++
+"""
+import os
+import numpy
+import matplotlib.pyplot as plt
+from onnxruntime import get_available_providers
+from onnx_array_api.ext_test_case import example_path
+from onnx_array_api.ort.ort_optimizers import ort_optimized_model
+from onnx_array_api.ort.ort_profile import ort_profile
+
+
+filename = example_path("data/small.onnx")
+optimized = filename + ".optimized.onnx"
+
+if not os.path.exists(optimized):
+    ort_optimized_model(filename, output=optimized)
+print(optimized)
+
+#############################
+# Profiling
+# +++++++++
+
+feeds = {"input": numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)}
+prof_base = ort_profile(
+    filename,
+    feeds,
+    repeat=6,
+    disable_optimization=True,
+    providers=["CPUExecutionProvider"],
+)
+prof_base.to_excel("prof_base.xlsx", index=False)
+prof_base
+
+#######################################
+# And the optimized model.
+
+prof_opt = ort_profile(
+    optimized,
+    feeds,
+    repeat=6,
+    disable_optimization=True,
+    providers=["CPUExecutionProvider"],
+)
+prof_opt
+
+#######################################
+# And the graph is:
+
+
+def plot_profile(df, ax0, ax1=None, title=None):
+    gr_dur = (
+        df[["dur", "args_op_name"]].groupby("args_op_name").sum().sort_values("dur")
+    )
+    gr_dur.plot.barh(ax=ax0)
+    if title is not None:
+        ax0.set_title(title)
+    if ax1 is not None:
+        gr_n = (
+            df[["dur", "args_op_name"]]
+            .groupby("args_op_name")
+            .count()
+            .sort_values("dur")
+        )
+        gr_n = gr_n.loc[gr_dur.index, :]
+        gr_n.plot.barh(ax=ax1)
+        ax1.set_title("n occurrences")
+
+
+unique_op = set(prof_base["args_op_name"])
+fig, ax = plt.subplots(2, 2, figsize=(10, len(unique_op)), sharex="col")
+plot_profile(prof_base, ax[0, 0], ax[0, 1], title="baseline")
+plot_profile(prof_opt, ax[1, 0], ax[1, 1], title="optimized")
+
+fig.savefig("plot_profiling.png")
+
+##################################################
+# Merging profiles
+# ++++++++++++++++
+#
+# Let's try to compare both profiles assuming every iteration
+# processes the same image and the input and output sizes are
+# the same at every iteration.
+
+
+def preprocess(df):
+    groupkey = [
+        "args_op_name",
+        "args_output_type_shape",
+        "args_input_type_shape",
+        "args_provider",
+    ]
+
+    def _idx(row):
+        """
+        There may be multiple nodes with the same
+        input/output types and shapes.
+        This function gives every instance a distinct id.
+        The first unique op with the same I/O receives the index 0.
+        The counter restarts when the session goes to the
+        next image.
+        """
+        if row["cat"] == "Session":
+            occurences[0] = {}
+            return -1
+        assert "idx" not in groupkey
+        vals = [row[k] for k in groupkey]
+        key = tuple(map(str, vals))
+        if key not in occurences[0]:
+            occurences[0][key] = 0
+        else:
+            occurences[0][key] += 1
+        return occurences[0][key]
+
+    df = df.copy()
+    occurences = [{}]
+    df["idx"] = df.apply(_idx, axis=1)
+    df = df[(df["cat"] == "Node") & df["name"].str.contains("kernel_time")]
+    groupkey.append("idx")
+    for c in groupkey:
+        if c != "idx":
+            df[c] = df[c].apply(str)
+    gr = df[groupkey + ["dur"]].groupby(groupkey)
+    return gr.sum()
+
+
+base = preprocess(prof_base)
+opti = preprocess(prof_opt)
+merge = base.merge(
+    opti, how="outer", suffixes=("base", "opti"), left_index=True, right_index=True
+)
+merge = merge.reset_index(drop=False)
+merge.to_excel("plot_profiling_merged.xlsx", index=False)
+merge
+
+
+#####################################################
+# Aggregation
+
+
+def classify(row):
+    if numpy.isnan(row["duropti"]):
+        return "-"
+    if numpy.isnan(row["durbase"]):
+        return "+"
+    return "="
+
+
+keys = {"float": "f"}
+
+
+def process_shape(s):
+    value = eval(s)
+    ns = []
+    for v in value:
+        if len(v) != 1:
+            raise NotImplementedError(f"Unexpected value {v} in {s!r}.")
+        k, v = list(v.items())[0]
+        n = "-".join([keys[k], "x".join(map(str, v))])
+        ns.append(n)
+    return ",".join(ns)
+
+
+def label(row):
+    name = row["args_op_name"]
+    inshape = process_shape(row["args_input_type_shape"])
+    outshape = process_shape(row["args_output_type_shape"])
+    side = row["side"][0]
+    prov = row["args_provider"][:3]
+    idx = row["idx"]
+    return f"[{side}{prov}] {name}({inshape})->{outshape}[{idx}]"
+
+
+df = merge.copy()
+df["side"] = df.apply(classify, axis=1)
+df["label"] = df.apply(label, axis=1)
+gr = (
+    df[["label", "durbase", "duropti", "idx"]]
+    .groupby("label")
+    .agg({"durbase": numpy.sum, "duropti": numpy.sum, "idx": max})
+)
+gr
+
+################################
+# Final plot
+# ++++++++++
+
+# let's filter out insignificant operators.
+grmax = gr["durbase"] + gr["duropti"]
+total = grmax.sum()
+grmax /= total
+gr = gr[grmax >= 0.01]
+
+
+fig, ax = plt.subplots(1, 2, figsize=(14, min(gr.shape[0], 500)), sharey=True)
+gr[["durbase", "duropti"]].plot.barh(ax=ax[0])
+ax[0].set_title("Side by side duration")
+gr = gr.copy()
+gr["idx"] += 1
+gr[["idx"]].plot.barh(ax=ax[1])
+ax[1].set_title("Side by side count")
+fig.tight_layout()
+fig.savefig("plot_profiling_side_by_side.png")
+
+
+########################################
+# On CUDA
+# +++++++
+
+
+if "CUDAExecutionProvider" in get_available_providers():
+    print("Profiling on CUDA")
+    prof_base = ort_profile(
+        filename,
+        feeds,
+        repeat=6,
+        disable_optimization=True,
+        providers=["CUDAExecutionProvider"],
+    )
+    prof_opti = ort_profile(
+        optimized,
+        feeds,
+        repeat=6,
+        disable_optimization=True,
+        providers=["CUDAExecutionProvider"],
+    )
+
+    unique_op = set(prof_base["args_op_name"])
+    fig, ax = plt.subplots(2, 2, figsize=(10, len(unique_op)), sharex="col")
+    plot_profile(prof_base, ax[0, 0], ax[0, 1], title="baseline")
+    plot_profile(prof_opt, ax[1, 0], ax[1, 1], title="optimized")
+    fig.savefig("plot_profiling_cuda.png")
+else:
+    print(f"CUDA not available in {get_available_providers()}")
+    fig, ax = None, None
+
+ax
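`ort_profile` presumably wraps onnxruntime's built-in profiler, which is worth knowing on its own: with `SessionOptions.enable_profiling`, the session records chrome-trace style JSON events whose `cat`, `name`, `dur` and `args_*` fields are exactly the columns the example above manipulates. A minimal sketch of that underlying mechanism, assuming a model at a placeholder path with an input named "input":

    import json
    import numpy
    from onnxruntime import InferenceSession, SessionOptions

    so = SessionOptions()
    so.enable_profiling = True  # record one event per kernel execution
    sess = InferenceSession("model.onnx", so, providers=["CPUExecutionProvider"])
    img = numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)
    sess.run(None, {"input": img})
    trace = sess.end_profiling()  # path of the JSON trace written by onnxruntime
    with open(trace) as f:
        events = json.load(f)
    print(len(events), "profiling events")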
(binary file, name not captured)

5.84 KB
Binary file not shown.

_doc/examples/prof_base.xlsx

29.8 KB
Binary file not shown.

_unittests/ut_ort/test_ort_profile.py

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+import unittest
+import numpy as np
+from pandas import DataFrame
+from onnx_array_api.npx import absolute, jit_onnx
+from onnx_array_api.ext_test_case import ExtTestCase
+from onnx_array_api.ort.ort_optimizers import ort_optimized_model
+from onnx_array_api.ort.ort_profile import ort_profile
+
+
+class TestOrtProfile(ExtTestCase):
+    def test_ort_profile(self):
+        def l1_loss(x, y):
+            return absolute(x - y).sum()
+
+        def l2_loss(x, y):
+            return ((x - y) ** 2).sum()
+
+        def myloss(x, y):
+            return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])
+
+        jitted_myloss = jit_onnx(myloss)
+        x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
+        y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
+        jitted_myloss(x, y)
+        onx = jitted_myloss.get_onnx()
+        feeds = {"x0": x, "x1": y}
+        self.assertRaise(lambda: ort_optimized_model(onx, "NO"), ValueError)
+        optimized = ort_optimized_model(onx)
+        prof = ort_profile(optimized, feeds)
+        self.assertIsInstance(prof, DataFrame)
+        prof = ort_profile(optimized, feeds, as_df=False)
+        self.assertIsInstance(prof, list)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

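Since the file ends with the standard `unittest.main` guard, the new test can be run directly:

    python _unittests/ut_ort/test_ort_profile.py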
onnx_array_api/ort/ort_optimizers.py

Lines changed: 5 additions & 1 deletion

@@ -33,7 +33,11 @@ def ort_optimized_model(
     so = SessionOptions()
     so.graph_optimization_level = glevel
     so.optimized_model_filepath = str(cache)
-    InferenceSession(onx if isinstance(onx, str) else onx.SerializeToString(), so)
+    InferenceSession(
+        onx if isinstance(onx, str) else onx.SerializeToString(),
+        so,
+        providers=["CPUExecutionProvider"],
+    )
     if output is None and not cache.exists():
         raise RuntimeError(f"The optimized model {str(cache)!r} not found.")
     if output is not None:

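For reference, the two calling conventions of `ort_optimized_model` that appear in this commit, one writing the optimized graph to disk and one returning the optimized model; "model.onnx" is a placeholder path:

    from onnx_array_api.ort.ort_optimizers import ort_optimized_model

    # Write the optimized graph next to the original file,
    # as plot_profiling.py does:
    ort_optimized_model("model.onnx", output="model.onnx.optimized.onnx")

    # Or pass an in-memory ModelProto and get the optimized model back,
    # as the new unit test does:
    # optimized = ort_optimized_model(onx)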