Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit75d62a0

Browse files
authored
Add an export to convert an onnx graph into light API code (#46)
* Add an export to convert an onnx graph into light API code* fix unit tests* fix annotations* fix documentation* doc
1 parentdd11424 commit75d62a0

File tree

9 files changed

+489
-11
lines changed

9 files changed

+489
-11
lines changed

‎CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.1.3
55
+++++
66

7+
*:pr:`46`: adds an export to convert an onnx graph into light API code
78
*:pr:`45`: fixes light API for operators with two outputs
89

910
0.1.2

‎README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,4 @@ The euclidean distance looks like the following:
141141
The library is released on
142142
`pypi/onnx-array-api<https://pypi.org/project/onnx-array-api/>`_
143143
and its documentation is published at
144-
`(Numpy) Array API forONNX<https://sdpython.github.io/doc/onnx-array-api/dev/>`_.
144+
`APIs to createONNX Graphs<https://sdpython.github.io/doc/onnx-array-api/dev/>`_.

‎_doc/api/light_api.rst

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,67 @@
22
onnx_array_api.light_api
33
========================
44

5+
6+
Main API
7+
========
8+
59
start
6-
=====
10+
+++++
711

812
..autofunction::onnx_array_api.light_api.start
913

14+
translate
15+
+++++++++
16+
17+
..autofunction::onnx_array_api.light_api.translate
18+
19+
Classes for the Light API
20+
=========================
21+
1022
OnnxGraph
11-
=========
23+
+++++++++
1224

1325
..autoclass::onnx_array_api.light_api.OnnxGraph
1426
:members:
1527

1628
BaseVar
17-
=======
29+
+++++++
1830

1931
..autoclass::onnx_array_api.light_api.var.BaseVar
2032
:members:
2133

2234
Var
23-
===
35+
+++
2436

2537
..autoclass::onnx_array_api.light_api.Var
2638
:members:
2739
:inherited-members:
2840

2941
Vars
30-
====
42+
++++
3143

3244
..autoclass::onnx_array_api.light_api.Vars
3345
:members:
3446
:inherited-members:
47+
48+
Classes for the Translater
49+
==========================
50+
51+
Emitter
52+
+++++++
53+
54+
..autoclass::onnx_array_api.light_api.translate.Emitter
55+
:members:
56+
57+
EventType
58+
+++++++++
59+
60+
..autoclass::onnx_array_api.light_api.translate.EventType
61+
:members:
62+
63+
Translater
64+
++++++++++
65+
66+
..autoclass::onnx_array_api.light_api.translate.Translater
67+
:members:
68+

‎_doc/index.rst

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ The objective is to speed up the implementation of converter libraries.
4545
CHANGELOGS
4646
license
4747

48-
**Numpy API**
48+
Numpy API
49+
+++++++++
4950

5051
Sources available on
5152
`github/onnx-array-api<https://github.com/sdpython/onnx-array-api>`_.
@@ -109,7 +110,8 @@ Sources available on
109110
res = jitted_myloss(x, y)
110111
print(to_dot(jitted_myloss.get_onnx()))
111112

112-
**Light API**
113+
Light API
114+
+++++++++
113115

114116
..runpython::
115117
:showcode:
@@ -135,3 +137,9 @@ Sources available on
135137
)
136138

137139
print(onnx_simple_text_plot(model))
140+
141+
142+
Older versions
143+
++++++++++++++
144+
145+
* `0.1.2<../v0.1.2/index.html>`_
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
importunittest
2+
fromtextwrapimportdedent
3+
importnumpyasnp
4+
fromonnximportModelProto,TensorProto
5+
fromonnx.defsimportonnx_opset_version
6+
fromonnx.referenceimportReferenceEvaluator
7+
fromonnx_array_api.ext_test_caseimportExtTestCase
8+
fromonnx_array_api.light_apiimportstart,translate
9+
10+
OPSET_API=min(19,onnx_opset_version()-1)
11+
12+
13+
classTestTranslate(ExtTestCase):
14+
deftest_exp(self):
15+
onx=start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
16+
self.assertIsInstance(onx,ModelProto)
17+
self.assertIn("Exp",str(onx))
18+
ref=ReferenceEvaluator(onx)
19+
a=np.arange(10).astype(np.float32)
20+
got=ref.run(None, {"X":a})[0]
21+
self.assertEqualArray(np.exp(a),got)
22+
23+
code=translate(onx)
24+
expected=dedent(
25+
"""
26+
(
27+
start(opset=19)
28+
.vin('X', elem_type=TensorProto.FLOAT)
29+
.bring('X')
30+
.Exp()
31+
.rename('Y')
32+
.bring('Y')
33+
.vout(elem_type=TensorProto.FLOAT)
34+
.to_onnx()
35+
)"""
36+
).strip("\n")
37+
self.assertEqual(expected,code)
38+
39+
onx2= (
40+
start(opset=19)
41+
.vin("X",elem_type=TensorProto.FLOAT)
42+
.bring("X")
43+
.Exp()
44+
.rename("Y")
45+
.bring("Y")
46+
.vout(elem_type=TensorProto.FLOAT)
47+
.to_onnx()
48+
)
49+
ref=ReferenceEvaluator(onx2)
50+
a=np.arange(10).astype(np.float32)
51+
got=ref.run(None, {"X":a})[0]
52+
self.assertEqualArray(np.exp(a),got)
53+
54+
deftest_transpose(self):
55+
onx= (
56+
start(opset=19)
57+
.vin("X")
58+
.reshape((-1,1))
59+
.Transpose(perm=[1,0])
60+
.rename("Y")
61+
.vout()
62+
.to_onnx()
63+
)
64+
self.assertIsInstance(onx,ModelProto)
65+
self.assertIn("Transpose",str(onx))
66+
ref=ReferenceEvaluator(onx)
67+
a=np.arange(10).astype(np.float32)
68+
got=ref.run(None, {"X":a})[0]
69+
self.assertEqualArray(a.reshape((-1,1)).T,got)
70+
71+
code=translate(onx)
72+
expected=dedent(
73+
"""
74+
(
75+
start(opset=19)
76+
.vin('X', elem_type=TensorProto.FLOAT)
77+
.bring('X', 'r')
78+
.Reshape()
79+
.rename('r0_0')
80+
.bring('r0_0')
81+
.Transpose(perm=[1, 0])
82+
.rename('Y')
83+
.bring('Y')
84+
.vout(elem_type=TensorProto.FLOAT)
85+
.to_onnx()
86+
)"""
87+
).strip("\n")
88+
self.assertEqual(expected,code)
89+
90+
deftest_topk_reverse(self):
91+
onx= (
92+
start(opset=19)
93+
.vin("X",np.float32)
94+
.vin("K",np.int64)
95+
.bring("X","K")
96+
.TopK(largest=0)
97+
.rename("Values","Indices")
98+
.vout()
99+
.to_onnx()
100+
)
101+
self.assertIsInstance(onx,ModelProto)
102+
ref=ReferenceEvaluator(onx)
103+
x=np.array([[0,1,2,3], [9,8,7,6]],dtype=np.float32)
104+
k=np.array([2],dtype=np.int64)
105+
got=ref.run(None, {"X":x,"K":k})
106+
self.assertEqualArray(np.array([[0,1], [6,7]],dtype=np.float32),got[0])
107+
self.assertEqualArray(np.array([[0,1], [3,2]],dtype=np.int64),got[1])
108+
109+
code=translate(onx)
110+
expected=dedent(
111+
"""
112+
(
113+
start(opset=19)
114+
.vin('X', elem_type=TensorProto.FLOAT)
115+
.vin('K', elem_type=TensorProto.INT64)
116+
.bring('X', 'K')
117+
.TopK(axis=-1, largest=0, sorted=1)
118+
.rename('Values', 'Indices')
119+
.bring('Values')
120+
.vout(elem_type=TensorProto.FLOAT)
121+
.bring('Indices')
122+
.vout(elem_type=TensorProto.FLOAT)
123+
.to_onnx()
124+
)"""
125+
).strip("\n")
126+
self.assertEqual(expected,code)
127+
128+
129+
if__name__=="__main__":
130+
# TestLightApi().test_topk()
131+
unittest.main(verbosity=2)

‎onnx_array_api/light_api/__init__.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
fromtypingimportDict,Optional
2+
fromonnximportModelProto
23
from .modelimportOnnxGraph
4+
from .translateimportTranslater
35
from .varimportVar,Vars
46

57

@@ -34,8 +36,48 @@ def start(
3436
from onnx_array_api.light_api import start
3537
3638
onx = (
37-
start().vin("X").vin("Y").bring("X", "Y").Add().rename("Z").vout().to_onnx()
39+
start()
40+
.vin("X")
41+
.vin("Y")
42+
.bring("X", "Y")
43+
.Add()
44+
.rename("Z")
45+
.vout()
46+
.to_onnx()
3847
)
3948
print(onx)
4049
"""
4150
returnOnnxGraph(opset=opset,opsets=opsets,is_function=is_function)
51+
52+
53+
deftranslate(proto:ModelProto,single_line=False)->str:
54+
"""
55+
Translates an ONNX proto into a code using :ref:`l-light-api`
56+
to describe the ONNX graph.
57+
58+
:param proto: model to translate
59+
:param single_line: as a single line or not
60+
:return: code
61+
62+
.. runpython::
63+
:showcode:
64+
65+
from onnx_array_api.light_api import start, translate
66+
67+
onx = (
68+
start()
69+
.vin("X")
70+
.reshape((-1, 1))
71+
.Transpose(perm=[1, 0])
72+
.rename("Y")
73+
.vout()
74+
.to_onnx()
75+
)
76+
code = translate(onx)
77+
print(code)
78+
"""
79+
tr=Translater(proto)
80+
rows=tr.export()
81+
ifsingle_line:
82+
return".".join(rows)
83+
return"".join(["(\n ","\n .".join(rows),"\n)"])

‎onnx_array_api/light_api/annotations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
ELEMENT_TYPE_NAME= {
1313
getattr(TensorProto,k):k
1414
forkindir(TensorProto)
15-
ifisinstance(getattr(TensorProto,k),int)
15+
ifisinstance(getattr(TensorProto,k),int)and"_"notink
1616
}
1717

1818
_type_numpy= {

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp