Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit9de394e

Browse files
xadupresdpython
andauthored
Extends export onnx to code to support inner API (#47)
* Extend to use inner API* export subgraphs* update code* refactoring* add more tests* fix conversion* fix ut* fix ut* fix doc* doc* verbostiy* disable unstable test---------Co-authored-by: Xavier Dupré <xavier.dupre@gmail.com>
1 parent75d62a0 commit9de394e

File tree

14 files changed

+1049
-129
lines changed

14 files changed

+1049
-129
lines changed

‎CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.1.3
55
+++++
66

7+
*:pr:`47`: extends export onnx to code to support inner API
78
*:pr:`46`: adds an export to convert an onnx graph into light API code
89
*:pr:`45`: fixes light API for operators with two outputs
910

‎_doc/api/light_api.rst

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,16 @@ Vars
4848
Classes for the Translater
4949
==========================
5050

51+
BaseEmitter
52+
+++++++++++
53+
54+
..autoclass::onnx_array_api.light_api.emitter.BaseEmitter
55+
:members:
56+
5157
Emitter
5258
+++++++
5359

54-
..autoclass::onnx_array_api.light_api.translate.Emitter
60+
..autoclass::onnx_array_api.light_api.emitter.Emitter
5561
:members:
5662

5763
EventType
@@ -60,6 +66,12 @@ EventType
6066
..autoclass::onnx_array_api.light_api.translate.EventType
6167
:members:
6268

69+
InnerEmitter
70+
++++++++++++
71+
72+
..autoclass::onnx_array_api.light_api.inner_emitter.InnerEmitter
73+
:members:
74+
6375
Translater
6476
++++++++++
6577

Binary file not shown.
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
importunittest
2+
fromtypingimportAny,Dict,List,Optional
3+
fromdifflibimportunified_diff
4+
importpackaging.versionaspv
5+
importnumpy
6+
fromnumpy.testingimportassert_allclose
7+
importonnx.backend.base
8+
importonnx.backend.test
9+
importonnx.shape_inference
10+
importonnx.version_converter
11+
fromonnximportModelProto,TensorProto,__version__asonnx_version
12+
fromonnx.helperimport (
13+
make_function,
14+
make_graph,
15+
make_model,
16+
make_node,
17+
make_opsetid,
18+
make_tensor_value_info,
19+
)
20+
fromonnx.numpy_helperimportfrom_array,to_array
21+
fromonnx.backend.baseimportDevice,DeviceType
22+
fromonnx_array_api.referenceimportExtendedReferenceEvaluator
23+
fromonnx_array_api.light_apiimporttranslate
24+
fromonnx_array_api.plotting.text_plotimportonnx_simple_text_plot
25+
26+
27+
classReferenceImplementationError(RuntimeError):
28+
"Fails, export cannot be compared."
29+
pass
30+
31+
32+
classExportWrapper:
33+
apis= ["onnx","light"]
34+
35+
def__init__(self,model):
36+
self.model=model
37+
self.expected_sess=ExtendedReferenceEvaluator(self.model)
38+
39+
@property
40+
definput_names(self):
41+
returnself.expected_sess.input_names
42+
43+
@property
44+
definput_types(self):
45+
returnself.expected_sess.input_types
46+
47+
@property
48+
defoutput_names(self):
49+
returnself.expected_sess.output_names
50+
51+
@property
52+
defoutput_types(self):
53+
returnself.expected_sess.output_types
54+
55+
defrun(
56+
self,names:Optional[List[str]],feeds:Optional[Dict[str,Any]]=None
57+
)->List[Any]:
58+
try:
59+
expected=self.expected_sess.run(names,feeds)
60+
except (RuntimeError,AssertionError,TypeError,KeyError)ase:
61+
raiseReferenceImplementationError(
62+
f"ReferenceImplementation fails with{onnx_simple_text_plot(self.model)}"
63+
f"\n--RAW--\n{self.model}"
64+
)frome
65+
66+
forapiinself.apis:
67+
try:
68+
code=translate(self.model,api=api)
69+
exceptNotImplementedError:
70+
continue
71+
exceptValueErrorase:
72+
raiseAssertionError(
73+
f"Unable to translate model for api{api!r}, "
74+
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
75+
f"\n--EXPECTED--\n{expected}"
76+
)frome
77+
try:
78+
code_compiled=compile(code,"<string>",mode="exec")
79+
exceptExceptionase:
80+
new_code="\n".join(
81+
[f"{i+1:04}{line}"fori,lineinenumerate(code.split("\n"))]
82+
)
83+
raiseAssertionError(f"ERROR{e}\n{new_code}")
84+
85+
locs= {
86+
"np":numpy,
87+
"to_array":to_array,
88+
"from_array":from_array,
89+
"TensorProto":TensorProto,
90+
"make_function":make_function,
91+
"make_opsetid":make_opsetid,
92+
"make_model":make_model,
93+
"make_graph":make_graph,
94+
"make_node":make_node,
95+
"make_tensor_value_info":make_tensor_value_info,
96+
}
97+
globs=locs.copy()
98+
try:
99+
exec(code_compiled,globs,locs)
100+
except (TypeError,NameError,ValueError)ase:
101+
new_code="\n".join(
102+
[f"{i+1:04}{line}"fori,lineinenumerate(code.split("\n"))]
103+
)
104+
raiseAssertionError(
105+
f"Unable to executed code for api{api!r}\n{new_code}"
106+
)frome
107+
export_model=locs["model"]
108+
ref=ExtendedReferenceEvaluator(export_model)
109+
try:
110+
got=ref.run(names,feeds)
111+
except (TypeError,AttributeError)ase:
112+
diff="\n".join(
113+
unified_diff(
114+
str(self.model).split("\n"),
115+
str(export_model).split("\n"),
116+
fromfile="before",
117+
tofile="after",
118+
)
119+
)
120+
raiseAssertionError(
121+
f"Unable to run the exported model for api{api!r}, "
122+
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
123+
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
124+
f"\n--CODE--\n{code}"
125+
f"\n--FEEDS--\n{feeds}"
126+
f"\n--EXPECTED--\n{expected}"
127+
f"\n--DIFF--\n{diff}"
128+
)frome
129+
iflen(expected)!=len(got):
130+
raiseAssertionError(
131+
f"Unexpected number of outputs for api{api!r}, "
132+
f"{len(expected)} !={len(got)}."
133+
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
134+
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
135+
)
136+
fora,binzip(expected,got):
137+
ifnotisinstance(a,numpy.ndarray):
138+
continue
139+
ifa.shape!=b.shapeora.dtype!=b.dtype:
140+
raiseAssertionError(
141+
f"Shape or type discrepancies for api{api!r}."
142+
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
143+
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
144+
)
145+
ifa.dtypein (numpy.str_,object,numpy.object_)orisinstance(
146+
a.dtype,getattr(getattr(numpy,"dtypes",None),"StrDType",type)
147+
):
148+
ifa.tolist()!=b.tolist():
149+
raiseAssertionError(
150+
f"Text discrepancies for api{api!r} with a.dtype={a.dtype} "
151+
f"and b.dtype={b.dtype}"
152+
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
153+
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
154+
)
155+
continue
156+
try:
157+
assert_allclose(a,b,atol=1e-3)
158+
except (AssertionError,TypeError)ase:
159+
raiseAssertionError(
160+
f"Discrepancies for api{api!r} with a.dtype={a.dtype} "
161+
f"and b.dtype={b.dtype} (type-dtype={type(a.dtype)})"
162+
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
163+
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
164+
)frome
165+
166+
returnexpected
167+
168+
169+
classExportBackendRep(onnx.backend.base.BackendRep):
170+
def__init__(self,session):
171+
self._session=session
172+
173+
defrun(self,inputs,**kwargs):
174+
ifisinstance(inputs,numpy.ndarray):
175+
inputs= [inputs]
176+
ifisinstance(inputs,list):
177+
iflen(inputs)==len(self._session.input_names):
178+
feeds=dict(zip(self._session.input_names,inputs))
179+
else:
180+
feeds= {}
181+
pos_inputs=0
182+
forinp,tshapeinzip(
183+
self._session.input_names,self._session.input_types
184+
):
185+
shape=tuple(d.dim_valuefordintshape.tensor_type.shape.dim)
186+
ifshape==inputs[pos_inputs].shape:
187+
feeds[inp]=inputs[pos_inputs]
188+
pos_inputs+=1
189+
ifpos_inputs>=len(inputs):
190+
break
191+
elifisinstance(inputs,dict):
192+
feeds=inputs
193+
else:
194+
raiseTypeError(f"Unexpected input type{type(inputs)!r}.")
195+
outs=self._session.run(None,feeds)
196+
returnouts
197+
198+
199+
classExportBackend(onnx.backend.base.Backend):
200+
@classmethod
201+
defis_opset_supported(cls,model):# pylint: disable=unused-argument
202+
returnTrue,""
203+
204+
@classmethod
205+
defsupports_device(cls,device:str)->bool:
206+
d=Device(device)
207+
returnd.type==DeviceType.CPU# type: ignore[no-any-return]
208+
209+
@classmethod
210+
defcreate_inference_session(cls,model):
211+
returnExportWrapper(model)
212+
213+
@classmethod
214+
defprepare(
215+
cls,model:Any,device:str="CPU",**kwargs:Any
216+
)->ExportBackendRep:
217+
ifisinstance(model,ExportWrapper):
218+
returnExportBackendRep(model)
219+
ifisinstance(model, (str,bytes,ModelProto)):
220+
inf=cls.create_inference_session(model)
221+
returncls.prepare(inf,device,**kwargs)
222+
raiseTypeError(f"Unexpected type{type(model)} for model.")
223+
224+
@classmethod
225+
defrun_model(cls,model,inputs,device=None,**kwargs):
226+
rep=cls.prepare(model,device,**kwargs)
227+
returnrep.run(inputs,**kwargs)
228+
229+
@classmethod
230+
defrun_node(cls,node,inputs,device=None,outputs_info=None,**kwargs):
231+
raiseNotImplementedError("Unable to run the model node by node.")
232+
233+
234+
backend_test=onnx.backend.test.BackendTest(ExportBackend,__name__)
235+
236+
# The following tests are too slow with the reference implementation (Conv).
237+
backend_test.exclude(
238+
"(FLOAT8|BFLOAT16|_opt_|_3d_|_momentum_|_4d_"
239+
"|test_adagrad"
240+
"|test_adam"
241+
"|test_ai_onnx_ml_"
242+
"|test_cast_FLOAT16"
243+
"|test_cast_FLOAT_to_STRING"
244+
"|test_castlike_FLOAT16"
245+
"|test_castlike_FLOAT_to_STRING"
246+
"|test_bernoulli"
247+
"|test_bvlc_alexnet"
248+
"|test_conv"# too long
249+
"|test_gradient_"
250+
"|test_densenet121"
251+
"|test_inception_v1"
252+
"|test_inception_v2"
253+
"|test_loop11_"
254+
"|test_loop16_seq_none"
255+
"|test_MaxPool2d"
256+
"|test_quantizelinear_e"
257+
"|test_resnet50"
258+
"|test_sequence_model"
259+
"|test_scan_sum"
260+
"|test_scatter_with_axis"
261+
"|test_scatter_without_axis"
262+
"|test_shufflenet"
263+
"|test_squeezenet"
264+
"|test_vgg19"
265+
"|test_zfnet512"
266+
")"
267+
)
268+
269+
ifpv.Version(onnx_version)<pv.Version("1.16.0"):
270+
backend_test.exclude("(test_strnorm|test_range_)")
271+
272+
# The following tests cannot pass because they consists in generating random number.
273+
backend_test.exclude("(test_bernoulli)")
274+
275+
# import all test cases at global scope to make them visible to python.unittest
276+
globals().update(backend_test.test_cases)
277+
278+
if__name__=="__main__":
279+
res=unittest.main(verbosity=2,exit=False)
280+
tests_run=res.result.testsRun
281+
errors=len(res.result.errors)
282+
skipped=len(res.result.skipped)
283+
unexpected_successes=len(res.result.unexpectedSuccesses)
284+
expected_failures=len(res.result.expectedFailures)
285+
print("---------------------------------")
286+
print(
287+
f"tests_run={tests_run} errors={errors} skipped={skipped} "
288+
f"unexpected_successes={unexpected_successes} "
289+
f"expected_failures={expected_failures}"
290+
)

‎_unittests/ut_light_api/test_light_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
importunittest
2+
importsys
23
fromtypingimportCallable,Optional
34
importnumpyasnp
45
fromonnximportModelProto
@@ -144,6 +145,7 @@ def list_ops_missing(self, n_inputs):
144145
f"{new_missing}\n{text}"
145146
)
146147

148+
@unittest.skipIf(sys.platform=="win32",reason="unstable test on Windows")
147149
deftest_list_ops_missing(self):
148150
self.list_ops_missing(1)
149151
self.list_ops_missing(2)

‎_unittests/ut_light_api/test_translate.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@
66
fromonnx.referenceimportReferenceEvaluator
77
fromonnx_array_api.ext_test_caseimportExtTestCase
88
fromonnx_array_api.light_apiimportstart,translate
9+
fromonnx_array_api.light_api.emitterimportEventType
910

1011
OPSET_API=min(19,onnx_opset_version()-1)
1112

1213

1314
classTestTranslate(ExtTestCase):
15+
deftest_event_type(self):
16+
self.assertEqual(
17+
EventType.to_str(EventType.INITIALIZER),"EventType.INITIALIZER"
18+
)
19+
1420
deftest_exp(self):
1521
onx=start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
1622
self.assertIsInstance(onx,ModelProto)
@@ -73,6 +79,8 @@ def test_transpose(self):
7379
"""
7480
(
7581
start(opset=19)
82+
.cst(np.array([-1, 1], dtype=np.int64))
83+
.rename('r')
7684
.vin('X', elem_type=TensorProto.FLOAT)
7785
.bring('X', 'r')
7886
.Reshape()

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp