Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitbf4dba0

Browse files
committed
finalize other domain epxressions
1 parent1c14009 commitbf4dba0

File tree

5 files changed

+147
-16
lines changed

5 files changed

+147
-16
lines changed

‎_doc/api/light_api.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ OnnxGraph
3434
BaseVar
3535
+++++++
3636

37+
..autoclass::onnx_array_api.light_api.var.BaseVar
38+
:members:
39+
40+
SubDomain
41+
+++++++++
42+
3743
..autoclass::onnx_array_api.light_api.var.BaseVar
3844
:members:
3945

‎_unittests/ut_light_api/test_light_api.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
importinspect
12
importunittest
23
fromtypingimportCallable,Optional
34
importnumpyasnp
@@ -12,6 +13,7 @@
1213
fromonnx.referenceimportReferenceEvaluator
1314
fromonnx_array_api.ext_test_caseimportExtTestCase,skipif_ci_windows
1415
fromonnx_array_api.light_apiimportstart,OnnxGraph,Var,g
16+
fromonnx_array_api.light_api.varimportSubDomain
1517
fromonnx_array_api.light_api._op_varimportOpsVar
1618
fromonnx_array_api.light_api._op_varsimportOpsVars
1719

@@ -473,21 +475,40 @@ def test_if(self):
473475
self.assertEqualArray(np.array([0],dtype=np.int64),got[0])
474476

475477
deftest_domain(self):
476-
onx= (
477-
start()
478-
.vin("X")
479-
.reshape((-1,1))
480-
.ai.onnx.ml.Normalizer(norm="L1")
481-
.rename("Y")
482-
.vout()
483-
.to_onnx()
484-
)
478+
onx=start(opsets={"ai.onnx.ml":3}).vin("X").reshape((-1,1)).rename("USE")
479+
480+
classA:
481+
defg(self):
482+
returnTrue
483+
484+
defah(self):
485+
returnTrue
486+
487+
setattr(A,"h",ah)
488+
489+
self.assertTrue(A().h())
490+
self.assertIn("(self)",str(inspect.signature(A.h)))
491+
self.assertTrue(issubclass(onx._ai,SubDomain))
492+
self.assertIsInstance(onx.ai,SubDomain)
493+
self.assertIsInstance(onx.ai.parent,Var)
494+
self.assertTrue(issubclass(onx._ai._onnx,SubDomain))
495+
self.assertIsInstance(onx.ai.onnx,SubDomain)
496+
self.assertIsInstance(onx.ai.onnx.parent,Var)
497+
self.assertTrue(issubclass(onx._ai._onnx._ml,SubDomain))
498+
self.assertIsInstance(onx.ai.onnx.ml,SubDomain)
499+
self.assertIsInstance(onx.ai.onnx.ml.parent,Var)
500+
self.assertIn("(self,",str(inspect.signature(onx._ai._onnx._ml.Normalizer)))
501+
onx=onx.ai.onnx.ml.Normalizer(norm="MAX")
502+
onx=onx.rename("Y").vout().to_onnx()
485503
self.assertIsInstance(onx,ModelProto)
486-
self.assertIn("Transpose",str(onx))
504+
self.assertIn("Normalizer",str(onx))
505+
self.assertIn('domain: "ai.onnx.ml"',str(onx))
506+
self.assertIn('input: "USE"',str(onx))
487507
ref=ReferenceEvaluator(onx)
488508
a=np.arange(10).astype(np.float32)
489509
got=ref.run(None, {"X":a})[0]
490-
self.assertEqualArray(a.reshape((-1,1)).T,got)
510+
expected= (a>0).astype(int).astype(np.float32).reshape((-1,1))
511+
self.assertEqualArray(expected,got)
491512

492513

493514
if__name__=="__main__":

‎onnx_array_api/light_api/annotations.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,18 @@ def domain(domain: str, op_type: Optional[str] = None) -> Callable:
2626
"""
2727
Registers one operator into a sub domain.
2828
"""
29-
pieces=domain.split(".")
30-
sub=pieces[0]
29+
names= [op_type]
3130

3231
defdecorate(op_method:Callable)->Callable:
32+
ifnames[0]isNone:
33+
names[0]=op_method.__name__
34+
3335
defwrapper(self,*args:List[Any],**kwargs:Dict[str,Any])->Any:
34-
ifnotself.hasattr(sub):
35-
raiseRuntimeError(f"Class has not registered subdomain{sub!r}.")
36-
returnop_method(self,*args,**kwargs)
36+
returnop_method(self.parent,*args,**kwargs)
3737

38+
wrapper.__qual__name__=f"[{domain}]{names[0]}"
39+
wrapper.__name__=f"[{domain}]{names[0]}"
40+
wrapper.__domain__=domain
3841
returnwrapper
3942

4043
returndecorate

‎onnx_array_api/light_api/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,9 @@ def make_node(
248248

249249
node=make_node(op_type,input_names,output_names,domain=domain,**kwargs)
250250
self.nodes.append(node)
251+
ifdomain!="":
252+
ifnotself.opsetsordomainnotinself.opsets:
253+
raiseRuntimeError(f"No opset value was given for domain{domain!r}.")
251254
returnnode
252255

253256
defcst(self,value:np.ndarray,name:Optional[str]=None)->"Var":

‎onnx_array_api/light_api/var.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
importinspect
12
fromtypingimportAny,Dict,List,Optional,Tuple,Union
23
importnumpyasnp
34
fromonnximportTensorProto
@@ -16,6 +17,26 @@
1617
from ._op_varsimportOpsVars
1718

1819

20+
classSubDomain:
21+
"""
22+
Declares a domain or a piece of it (if it contains '.' in its name).
23+
"""
24+
25+
def__init__(self,var:"BaseVar"):
26+
ifnotisinstance(var,BaseVar):
27+
raiseTypeError(f"Unexpected type{type(var)}.")
28+
self.parent=var
29+
30+
31+
def_getclassattr_(self,name):
32+
ifnothasattr(self.__class__,name):
33+
raiseTypeError(
34+
f"Unable to find{name!r} in class{self.__class__.__name__!r}, "
35+
f"available{dir(self.__class__)}."
36+
)
37+
returngetattr(self.__class__,name)
38+
39+
1940
classBaseVar:
2041
"""
2142
Represents an input, an initializer, a node, an output,
@@ -24,6 +45,83 @@ class BaseVar:
2445
:param parent: the graph containing the Variable
2546
"""
2647

48+
def__new__(cls,*args,**kwargs):
49+
res=super().__new__(cls)
50+
res.__init__(*args,**kwargs)
51+
ifgetattr(cls,"__incomplete",True):
52+
forkindir(cls):
53+
att=getattr(cls,k,None)
54+
ifnotatt:
55+
continue
56+
name=getattr(att,"__name__",None)
57+
ifnotnameorname[0]!="[":
58+
continue
59+
60+
# A function with a domain name
61+
ifnotinspect.isfunction(att):
62+
raiseRuntimeError(f"{cls.__name__}.{k} is not a function.")
63+
domain,op_type=name[1:].split("]")
64+
if"."indomain:
65+
spl=domain.split(".",maxsplit=1)
66+
dname=f"_{spl[0]}"
67+
ifnothasattr(cls,dname):
68+
d=type(
69+
f"{cls.__name__}{dname}", (SubDomain,), {"name":dname[1:]}
70+
)
71+
setattr(cls,dname,d)
72+
setattr(
73+
cls,
74+
spl[0],
75+
property(
76+
lambdaself,_name_=dname:_getclassattr_(self,_name_)(
77+
self
78+
)
79+
),
80+
)
81+
else:
82+
d=getattr(cls,dname)
83+
suffix=spl[0]
84+
forpinspl[1].split("."):
85+
dname=f"_{p}"
86+
suffix+=dname
87+
ifnothasattr(d,dname):
88+
sd=type(
89+
f"{cls.__name__}_{suffix}",
90+
(SubDomain,),
91+
{"name":suffix},
92+
)
93+
setattr(d,dname,sd)
94+
setattr(
95+
d,
96+
p,
97+
property(
98+
lambdaself,_name_=dname:_getclassattr_(
99+
self,_name_
100+
)(self.parent)
101+
),
102+
)
103+
d=sd
104+
else:
105+
d=getattr(d,dname)
106+
elifnothasattr(cls,domain):
107+
dname=f"_{domain}"
108+
d=type(f"{cls.__name__}{dname}", (SubDomain,), {"name":domain})
109+
setattr(cls,dname,d)
110+
setattr(
111+
cls,
112+
domain,
113+
property(
114+
lambdaself,_name_=dname:_getclassattr_(self,_name_)(
115+
self
116+
)
117+
),
118+
)
119+
120+
setattr(d,op_type,att)
121+
setattr(cls,"__incomplete",False)
122+
123+
returnres
124+
27125
def__init__(
28126
self,
29127
parent:OnnxGraph,

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp