Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit8b54ad1

Browse files
authored
Supports other domain for light API (#54)
* ut* first sketch* finalize other domain epxressions* docuemntation* extend the support of translate to other domain* documentation
1 parent06a15a9 commit8b54ad1

File tree

12 files changed

+333
-10
lines changed

12 files changed

+333
-10
lines changed

‎_doc/api/light_api.rst

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,15 @@ translate
1919
Classes for the Light API
2020
=========================
2121

22-
ProtoType
23-
+++++++++
22+
domain
23+
++++++
2424

25-
..autoclass::onnx_array_api.light_api.model.ProtoType
25+
..autofunction:: onnx_array_api.light_api.domain
26+
27+
BaseVar
28+
+++++++
29+
30+
..autoclass::onnx_array_api.light_api.var.BaseVar
2631
:members:
2732

2833
OnnxGraph
@@ -31,10 +36,16 @@ OnnxGraph
3136
..autoclass::onnx_array_api.light_api.OnnxGraph
3237
:members:
3338

34-
BaseVar
35-
+++++++
39+
ProtoType
40+
+++++++++
3641

37-
..autoclass::onnx_array_api.light_api.var.BaseVar
42+
..autoclass::onnx_array_api.light_api.model.ProtoType
43+
:members:
44+
45+
SubDomain
46+
+++++++++
47+
48+
..autoclass::onnx_array_api.light_api.var.SubDomain
3849
:members:
3950

4051
Var

‎_doc/tutorial/light_api.rst

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,32 @@ operator `+` to be available as well and that the case. They are
7676
defined in class:class:`Var <onnx_array_api.light_api.Var>` or
7777
:class:`Vars <onnx_array_api.light_api.Vars>` depending on the number of
7878
inputs they require. Their name starts with a lower letter.
79+
80+
Other domains
81+
=============
82+
83+
The following example uses operator *Normalizer* from domain
84+
*ai.onnx.ml*. The operator name is called with the syntax
85+
`<domain>.<operator name>`. The domain may have dots in its name
86+
but it must follow the python definition of a variable.
87+
The operator *Normalizer* becomes `ai.onnx.ml.Normalizer`.
88+
89+
..runpython::
90+
:showcode:
91+
92+
import numpy as np
93+
from onnx_array_api.light_api import start
94+
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
95+
96+
model = (
97+
start(opset=19, opsets={"ai.onnx.ml": 3})
98+
.vin("X")
99+
.reshape((-1, 1))
100+
.rename("USE")
101+
.ai.onnx.ml.Normalizer(norm="MAX")
102+
.rename("Y")
103+
.vout()
104+
.to_onnx()
105+
)
106+
107+
print(onnx_simple_text_plot(model))

‎_unittests/ut_light_api/test_light_api.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
importinspect
12
importunittest
23
fromtypingimportCallable,Optional
34
importnumpyasnp
@@ -12,6 +13,7 @@
1213
fromonnx.referenceimportReferenceEvaluator
1314
fromonnx_array_api.ext_test_caseimportExtTestCase,skipif_ci_windows
1415
fromonnx_array_api.light_apiimportstart,OnnxGraph,Var,g
16+
fromonnx_array_api.light_api.varimportSubDomain
1517
fromonnx_array_api.light_api._op_varimportOpsVar
1618
fromonnx_array_api.light_api._op_varsimportOpsVars
1719

@@ -472,7 +474,43 @@ def test_if(self):
472474
got=ref.run(None, {"X":-x})
473475
self.assertEqualArray(np.array([0],dtype=np.int64),got[0])
474476

477+
deftest_domain(self):
478+
onx=start(opsets={"ai.onnx.ml":3}).vin("X").reshape((-1,1)).rename("USE")
479+
480+
classA:
481+
defg(self):
482+
returnTrue
483+
484+
defah(self):
485+
returnTrue
486+
487+
setattr(A,"h",ah)
488+
489+
self.assertTrue(A().h())
490+
self.assertIn("(self)",str(inspect.signature(A.h)))
491+
self.assertTrue(issubclass(onx._ai,SubDomain))
492+
self.assertIsInstance(onx.ai,SubDomain)
493+
self.assertIsInstance(onx.ai.parent,Var)
494+
self.assertTrue(issubclass(onx._ai._onnx,SubDomain))
495+
self.assertIsInstance(onx.ai.onnx,SubDomain)
496+
self.assertIsInstance(onx.ai.onnx.parent,Var)
497+
self.assertTrue(issubclass(onx._ai._onnx._ml,SubDomain))
498+
self.assertIsInstance(onx.ai.onnx.ml,SubDomain)
499+
self.assertIsInstance(onx.ai.onnx.ml.parent,Var)
500+
self.assertIn("(self,",str(inspect.signature(onx._ai._onnx._ml.Normalizer)))
501+
onx=onx.ai.onnx.ml.Normalizer(norm="MAX")
502+
onx=onx.rename("Y").vout().to_onnx()
503+
self.assertIsInstance(onx,ModelProto)
504+
self.assertIn("Normalizer",str(onx))
505+
self.assertIn('domain: "ai.onnx.ml"',str(onx))
506+
self.assertIn('input: "USE"',str(onx))
507+
ref=ReferenceEvaluator(onx)
508+
a=np.arange(10).astype(np.float32)
509+
got=ref.run(None, {"X":a})[0]
510+
expected= (a>0).astype(int).astype(np.float32).reshape((-1,1))
511+
self.assertEqualArray(expected,got)
512+
475513

476514
if__name__=="__main__":
477-
TestLightApi().test_if()
515+
TestLightApi().test_domain()
478516
unittest.main(verbosity=2)

‎_unittests/ut_light_api/test_translate.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,39 @@ def test_export_if(self):
185185
self.maxDiff=None
186186
self.assertEqual(expected,code)
187187

188+
deftest_aionnxml(self):
189+
onx= (
190+
start(opset=19,opsets={"ai.onnx.ml":3})
191+
.vin("X")
192+
.reshape((-1,1))
193+
.rename("USE")
194+
.ai.onnx.ml.Normalizer(norm="MAX")
195+
.rename("Y")
196+
.vout()
197+
.to_onnx()
198+
)
199+
code=translate(onx)
200+
expected=dedent(
201+
"""
202+
(
203+
start(opset=19, opsets={'ai.onnx.ml': 3})
204+
.cst(np.array([-1, 1], dtype=np.int64))
205+
.rename('r')
206+
.vin('X', elem_type=TensorProto.FLOAT)
207+
.bring('X', 'r')
208+
.Reshape()
209+
.rename('USE')
210+
.bring('USE')
211+
.ai.onnx.ml.Normalizer(norm='MAX')
212+
.rename('Y')
213+
.bring('Y')
214+
.vout(elem_type=TensorProto.FLOAT)
215+
.to_onnx()
216+
)"""
217+
).strip("\n")
218+
self.maxDiff=None
219+
self.assertEqual(expected,code)
220+
188221

189222
if__name__=="__main__":
190223
TestTranslate().test_export_if()

‎_unittests/ut_light_api/test_translate_classic.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,72 @@ def test_fft(self):
252252
)
253253
raiseAssertionError(f"ERROR{e}\n{new_code}")
254254

255+
deftest_aionnxml(self):
256+
onx= (
257+
start(opset=19,opsets={"ai.onnx.ml":3})
258+
.vin("X")
259+
.reshape((-1,1))
260+
.rename("USE")
261+
.ai.onnx.ml.Normalizer(norm="MAX")
262+
.rename("Y")
263+
.vout()
264+
.to_onnx()
265+
)
266+
code=translate(onx,api="onnx")
267+
print(code)
268+
expected=dedent(
269+
"""
270+
opset_imports = [
271+
make_opsetid('', 19),
272+
make_opsetid('ai.onnx.ml', 3),
273+
]
274+
inputs = []
275+
outputs = []
276+
nodes = []
277+
initializers = []
278+
sparse_initializers = []
279+
functions = []
280+
initializers.append(
281+
from_array(
282+
np.array([-1, 1], dtype=np.int64),
283+
name='r'
284+
)
285+
)
286+
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
287+
nodes.append(
288+
make_node(
289+
'Reshape',
290+
['X', 'r'],
291+
['USE']
292+
)
293+
)
294+
nodes.append(
295+
make_node(
296+
'Normalizer',
297+
['USE'],
298+
['Y'],
299+
domain='ai.onnx.ml',
300+
norm='MAX'
301+
)
302+
)
303+
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
304+
graph = make_graph(
305+
nodes,
306+
'light_api',
307+
inputs,
308+
outputs,
309+
initializers,
310+
sparse_initializer=sparse_initializers,
311+
)
312+
model = make_model(
313+
graph,
314+
functions=functions,
315+
opset_imports=opset_imports
316+
)"""
317+
).strip("\n")
318+
self.maxDiff=None
319+
self.assertEqual(expected,code)
320+
255321

256322
if__name__=="__main__":
257323
# TestLightApi().test_topk()

‎onnx_array_api/light_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
fromtypingimportDict,Optional
22
fromonnximportModelProto
3+
from .annotationsimportdomain
34
from .modelimportOnnxGraph,ProtoType
45
from .translateimportTranslater
56
from .varimportVar,Vars

‎onnx_array_api/light_api/_op_var.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
fromtypingimportList,Optional,Union
2+
from .annotationsimportAI_ONNX_ML,domain
23

34

45
classOpsVar:
@@ -319,6 +320,10 @@ def Transpose(self, perm: Optional[List[int]] = None) -> "Var":
319320
perm=permor []
320321
returnself.make_node("Transpose",self,perm=perm)
321322

323+
@domain(AI_ONNX_ML)
324+
defNormalizer(self,norm:str="MAX"):
325+
returnself.make_node("Normalizer",self,norm=norm,domain=AI_ONNX_ML)
326+
322327

323328
def_complete():
324329
ops_to_add= [

‎onnx_array_api/light_api/annotations.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
fromtypingimportTuple,Union
1+
fromtypingimportAny,Callable,Dict,List,Optional,Tuple,Union
22
importnumpyasnp
33
fromonnximportFunctionProto,GraphProto,ModelProto,TensorProto,TensorShapeProto
44
fromonnx.helperimportnp_dtype_to_tensor_dtype
@@ -9,12 +9,47 @@
99
VAR_CONSTANT_TYPE=Union["Var",TensorProto,np.ndarray]
1010
GRAPH_PROTO=Union[FunctionProto,GraphProto,ModelProto]
1111

12+
AI_ONNX_ML="ai.onnx.ml"
13+
1214
ELEMENT_TYPE_NAME= {
1315
getattr(TensorProto,k):k
1416
forkindir(TensorProto)
1517
ifisinstance(getattr(TensorProto,k),int)and"_"notink
1618
}
1719

20+
21+
classSubDomain:
22+
pass
23+
24+
25+
defdomain(domain:str,op_type:Optional[str]=None)->Callable:
26+
"""
27+
Registers one operator into a sub domain. It should be used as a
28+
decorator. One example:
29+
30+
.. code-block:: python
31+
32+
@domain("ai.onnx.ml")
33+
def Normalizer(self, norm: str = "MAX"):
34+
return self.make_node("Normalizer", self, norm=norm, domain="ai.onnx.ml")
35+
"""
36+
names= [op_type]
37+
38+
defdecorate(op_method:Callable)->Callable:
39+
ifnames[0]isNone:
40+
names[0]=op_method.__name__
41+
42+
defwrapper(self,*args:List[Any],**kwargs:Dict[str,Any])->Any:
43+
returnop_method(self.parent,*args,**kwargs)
44+
45+
wrapper.__qual__name__=f"[{domain}]{names[0]}"
46+
wrapper.__name__=f"[{domain}]{names[0]}"
47+
wrapper.__domain__=domain
48+
returnwrapper
49+
50+
returndecorate
51+
52+
1853
_type_numpy= {
1954
np.float32:TensorProto.FLOAT,
2055
np.float64:TensorProto.DOUBLE,

‎onnx_array_api/light_api/emitter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
241241
outputs=kwargs["outputs"]
242242
ifkwargs.get("domain","")!="":
243243
domain=kwargs["domain"]
244-
raiseNotImplementedError(f"domain={domain!r} not supported yet.")
244+
op_type=f"{domain}.{op_type}"
245245
atts=kwargs.get("atts", {})
246246
args= []
247247
fork,vinatts.items():

‎onnx_array_api/light_api/inner_emitter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
120120
outputs=kwargs["outputs"]
121121
ifkwargs.get("domain","")!="":
122122
domain=kwargs["domain"]
123-
raiseNotImplementedError(f"domain={domain!r} not supported yet.")
124123

125124
before_lines= []
126125
lines= [

‎onnx_array_api/light_api/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,9 @@ def make_node(
248248

249249
node=make_node(op_type,input_names,output_names,domain=domain,**kwargs)
250250
self.nodes.append(node)
251+
ifdomain!="":
252+
ifnotself.opsetsordomainnotinself.opsets:
253+
raiseRuntimeError(f"No opset value was given for domain{domain!r}.")
251254
returnnode
252255

253256
defcst(self,value:np.ndarray,name:Optional[str]=None)->"Var":

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp