Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit7675869

Browse files
authored
Extend ExtendedReferenceEvaluator (#75)
* update requirements* add more operator to the reference evaluator* extend unit test copverage
1 parenta070da3 commit7675869

File tree

6 files changed

+222
-0
lines changed

6 files changed

+222
-0
lines changed

‎CHANGELOGS.rst‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
*:pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
78
*:pr:`71`: adds tools to compare two onnx graphs
89
*:pr:`61`: adds function to plot onnx model as graphs
910
*:pr:`60`: supports translation of local functions

‎_unittests/ut_reference/test_reference_ops.py‎

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,88 @@ def test_fused_matmul11(self):
5959
got=ref.run(None, {"X":a,"Y":a})
6060
self.assertEqualArray(a.T @a.T,got[0])
6161

62+
deftest_memcpy(self):
63+
model=make_model(
64+
make_graph(
65+
[
66+
make_node("MemcpyToHost", ["X"], ["Z"]),
67+
make_node("MemcpyFromHost", ["X"], ["Z"]),
68+
],
69+
"name",
70+
[make_tensor_value_info("X",TensorProto.FLOAT,None)],
71+
[make_tensor_value_info("Z",TensorProto.FLOAT,None)],
72+
),
73+
opset_imports=[make_opsetid("",18),make_opsetid("com.microsoft",1)],
74+
ir_version=9,
75+
)
76+
a=np.arange(4).reshape(-1,2).astype(np.float32)
77+
ref=ExtendedReferenceEvaluator(model)
78+
got=ref.run(None, {"X":a})
79+
self.assertEqualArray(a,got[0])
80+
81+
deftest_quick_gelu(self):
82+
fromonnxruntimeimportInferenceSession
83+
84+
foralphain [0.0,2.0]:
85+
model=make_model(
86+
make_graph(
87+
[
88+
make_node(
89+
"QuickGelu",
90+
["X"],
91+
["Z"],
92+
domain="com.microsoft",
93+
alpha=alpha,
94+
)
95+
],
96+
"name",
97+
[make_tensor_value_info("X",TensorProto.FLOAT,None)],
98+
[make_tensor_value_info("Z",TensorProto.FLOAT,None)],
99+
),
100+
opset_imports=[make_opsetid("",18),make_opsetid("com.microsoft",1)],
101+
ir_version=9,
102+
)
103+
sess=InferenceSession(
104+
model.SerializeToString(),providers=["CPUExecutionProvider"]
105+
)
106+
a=np.arange(4).reshape(-1,2).astype(np.float32)
107+
expected=sess.run(None, {"X":a})
108+
ref=ExtendedReferenceEvaluator(model)
109+
got=ref.run(None, {"X":a})
110+
self.assertEqualArray(expected[0],got[0])
111+
112+
deftest_scatter_elements(self):
113+
model=make_model(
114+
make_graph(
115+
[
116+
make_node(
117+
"ScatterElements",
118+
["data","indices","updates"],
119+
["Z"],
120+
axis=3,
121+
reduction="add",
122+
)
123+
],
124+
"name",
125+
[
126+
make_tensor_value_info("data",TensorProto.FLOAT,None),
127+
make_tensor_value_info("indices",TensorProto.INT64,None),
128+
make_tensor_value_info("updates",TensorProto.FLOAT,None),
129+
],
130+
[make_tensor_value_info("Z",TensorProto.FLOAT,None)],
131+
),
132+
opset_imports=[make_opsetid("",18)],
133+
)
134+
data=np.zeros(2**4,dtype=np.float32).reshape((2,2,2,2))
135+
indices=np.array([[[[0]]]],dtype=np.int64)
136+
updates=np.array([[[[1]]]],dtype=np.float32)
137+
y=np.array(
138+
[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],dtype=np.float32
139+
).reshape((2,2,2,2))
140+
ref=ExtendedReferenceEvaluator(model)
141+
got=ref.run(None, {"data":data,"indices":indices,"updates":updates})
142+
self.assertEqualArray(y,got[0])
143+
62144

63145
if__name__=="__main__":
64146
unittest.main(verbosity=2)

‎onnx_array_api/reference/evaluator.py‎

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from .ops.op_concatimportConcat
99
from .ops.op_constant_of_shapeimportConstantOfShape
1010
from .ops.op_fused_matmulimportFusedMatMul
11+
from .ops.op_memcpy_hostimportMemcpyFromHost,MemcpyToHost
12+
from .ops.op_quick_geluimportQuickGelu
13+
from .ops.op_scatter_elementsimportScatterElements
1114

1215

1316
logger=getLogger("onnx-array-api-eval")
@@ -34,6 +37,10 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
3437
CastLike_19,
3538
ConstantOfShape,
3639
FusedMatMul,
40+
MemcpyFromHost,
41+
MemcpyToHost,
42+
QuickGelu,
43+
ScatterElements,
3744
]
3845

3946
@staticmethod
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
fromonnx.reference.op_runimportOpRun
2+
3+
4+
classMemcpyFromHost(OpRun):
5+
def_run(self,x):
6+
return (x,)
7+
8+
9+
classMemcpyToHost(OpRun):
10+
def_run(self,x):
11+
return (x,)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
importnumpyasnp
2+
fromonnx.reference.op_runimportOpRun
3+
4+
5+
defsigmoid(x):# type: ignore
6+
ifx>0:
7+
return1/ (1+np.exp(-x))
8+
returnnp.exp(x)/ (1+np.exp(x))
9+
10+
11+
classQuickGelu(OpRun):
12+
op_domain="com.microsoft"
13+
14+
def__init__(self,onnx_node,run_params):# type: ignore
15+
OpRun.__init__(self,onnx_node,run_params)
16+
self.vf=np.vectorize(sigmoid)
17+
18+
def_run(self,X,alpha=1.0):
19+
iflen(X.shape)==0:
20+
return ((X*sigmoid(X*alpha)).astype(X.dtype),)
21+
ifX.size==0:
22+
return (X,)
23+
return ((X*self.vf(X*alpha)).astype(X.dtype),)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
importnumpyasnp
2+
3+
fromonnx.reference.op_runimportOpRun
4+
5+
6+
defscatter_elements(data,indices,updates,axis=0,reduction=None):# type: ignore
7+
ifreduction=="add":
8+
9+
deff(x,y):
10+
returnx+y
11+
12+
elifreduction=="min":
13+
14+
deff(x,y):
15+
returnmin(x,y)
16+
17+
elifreduction=="max":
18+
19+
deff(x,y):
20+
returnmax(x,y)
21+
22+
else:
23+
24+
deff(x,y):
25+
returny
26+
27+
ifaxis<0:
28+
axis=data.ndim+axis
29+
30+
iflen(data.shape)==1andaxis==0:
31+
scattered=np.copy(data)
32+
forpos,upinzip(indices,updates):
33+
scattered[pos]=f(scattered[pos],up)
34+
returnscattered
35+
36+
iflen(indices.shape)==2:
37+
scattered=np.copy(data)
38+
ifaxis==0:
39+
foriinrange(indices.shape[0]):
40+
forjinrange(indices.shape[1]):
41+
scattered[indices[i,j],j]=f(
42+
scattered[indices[i,j],j],updates[i,j]
43+
)
44+
else:
45+
foriinrange(indices.shape[0]):
46+
forjinrange(indices.shape[1]):
47+
scattered[i,indices[i,j]]=f(
48+
scattered[i,indices[i,j]],updates[i,j]
49+
)
50+
returnscattered
51+
52+
iflen(indices.shape)==3:
53+
scattered=np.copy(data)
54+
ifaxis==0:
55+
foriinrange(indices.shape[0]):
56+
forjinrange(indices.shape[1]):
57+
forkinrange(indices.shape[2]):
58+
scattered[indices[i,j,k],j,k]=f(
59+
scattered[indices[i,j,k],j,k],updates[i,j,k]
60+
)
61+
elifaxis==1:
62+
foriinrange(indices.shape[0]):
63+
forjinrange(indices.shape[1]):
64+
forkinrange(indices.shape[2]):
65+
scattered[i,indices[i,j,k],k]=f(
66+
scattered[i,indices[i,j,k],k],updates[i,j,k]
67+
)
68+
elifaxis==2:
69+
foriinrange(indices.shape[0]):
70+
forjinrange(indices.shape[1]):
71+
forkinrange(indices.shape[2]):
72+
scattered[i,j,indices[i,j,k]]=f(
73+
scattered[i,j,indices[i,j,k]],updates[i,j,k]
74+
)
75+
returnscattered
76+
77+
iflen(indices.shape)==4:
78+
scattered=np.copy(data)
79+
ifaxis==3:
80+
forainrange(indices.shape[0]):
81+
foriinrange(indices.shape[1]):
82+
forjinrange(indices.shape[2]):
83+
forkinrange(indices.shape[3]):
84+
scattered[a,i,j,indices[a,i,j,k]]=f(
85+
scattered[a,i,j,indices[a,i,j,k]],
86+
updates[a,i,j,k],
87+
)
88+
returnscattered
89+
90+
raiseRuntimeError(
91+
f"Not implemented for indices.shape={indices.shape} and axis={axis}"
92+
)
93+
94+
95+
classScatterElements(OpRun):
96+
def_run(self,data,indices,updates,axis=None,reduction=None):# type: ignore
97+
res=scatter_elements(data,indices,updates,axis=axis,reduction=reduction)
98+
return (res,)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp