Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit5edccab

Browse files
authored
Extends Array API to EagerOrt (#18)
* Extends Array API to EagerOrt* fix empty shape* fix shape* fix azure* refactoring* fix command line* CI* fix CI* fix CI
1 parent37fe094 commit5edccab

21 files changed

+444
-93
lines changed

‎.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ _cache/*
88
dist/*
99
build/*
1010
.eggs/*
11+
.hypothesis/*
1112
*egg-info/*
1213
_doc/auto_examples/*
1314
_doc/examples/_cache/*

‎_unittests/onnx-numpy-skips.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# API failures
2+
array_api_tests/test_creation_functions.py::test_arange
3+
array_api_tests/test_creation_functions.py::test_asarray_scalars
4+
array_api_tests/test_creation_functions.py::test_asarray_arrays
5+
array_api_tests/test_creation_functions.py::test_empty
6+
array_api_tests/test_creation_functions.py::test_empty_like
7+
array_api_tests/test_creation_functions.py::test_eye
8+
array_api_tests/test_creation_functions.py::test_full
9+
array_api_tests/test_creation_functions.py::test_full_like
10+
array_api_tests/test_creation_functions.py::test_linspace
11+
array_api_tests/test_creation_functions.py::test_meshgrid
12+
array_api_tests/test_creation_functions.py::test_ones_like
13+
array_api_tests/test_creation_functions.py::test_zeros_like

‎_unittests/onnx-ort-skips.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Not implementated by onnxruntime
2+
array_api_tests/test_creation_functions.py::test_arange
3+
array_api_tests/test_creation_functions.py::test_asarray_scalars
4+
array_api_tests/test_creation_functions.py::test_asarray_arrays
5+
array_api_tests/test_creation_functions.py::test_empty
6+
array_api_tests/test_creation_functions.py::test_empty_like
7+
array_api_tests/test_creation_functions.py::test_eye
8+
array_api_tests/test_creation_functions.py::test_full
9+
array_api_tests/test_creation_functions.py::test_full_like
10+
array_api_tests/test_creation_functions.py::test_linspace
11+
array_api_tests/test_creation_functions.py::test_meshgrid
12+
array_api_tests/test_creation_functions.py::test_ones
13+
array_api_tests/test_creation_functions.py::test_ones_like
14+
array_api_tests/test_creation_functions.py::test_zeros
15+
array_api_tests/test_creation_functions.py::test_zeros_like
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
importunittest
2+
frominspectimportisfunction,ismethod
3+
importnumpyasnp
4+
fromonnx_array_api.ext_test_caseimportExtTestCase
5+
fromonnx_array_api.array_apiimportonnx_numpyasxpn
6+
fromonnx_array_api.array_apiimportonnx_ortasxpo
7+
8+
# from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
9+
# from onnx_array_api.ort.ort_tensors import EagerOrtTensor
10+
11+
12+
classTestArraysApis(ExtTestCase):
13+
deftest_zeros_numpy_1(self):
14+
c=xpn.zeros(1)
15+
d=c.numpy()
16+
self.assertEqualArray(np.array([0],dtype=np.float32),d)
17+
18+
deftest_zeros_ort_1(self):
19+
c=xpo.zeros(1)
20+
d=c.numpy()
21+
self.assertEqualArray(np.array([0],dtype=np.float32),d)
22+
23+
deftest_ffinfo(self):
24+
dt=np.float32
25+
fi1=np.finfo(dt)
26+
fi2=xpn.finfo(dt)
27+
fi3=xpo.finfo(dt)
28+
dt1=fi1.dtype
29+
dt2=fi2.dtype
30+
dt3=fi3.dtype
31+
self.assertEqual(dt2,dt3)
32+
self.assertNotEqual(dt1.__class__,dt2.__class__)
33+
mi1=fi1.min
34+
mi2=fi2.min
35+
self.assertEqual(mi1,mi2)
36+
mi1=fi1.smallest_normal
37+
mi2=fi2.smallest_normal
38+
self.assertEqual(mi1,mi2)
39+
fornindir(fi1):
40+
ifn.startswith("__"):
41+
continue
42+
ifnin {"machar"}:
43+
continue
44+
v1=getattr(fi1,n)
45+
withself.subTest(att=n):
46+
v2=getattr(fi2,n)
47+
v3=getattr(fi3,n)
48+
ifisfunction(v1)orismethod(v1):
49+
try:
50+
v1=v1()
51+
exceptTypeError:
52+
continue
53+
v2=v2()
54+
v3=v3()
55+
ifv1!=v2:
56+
raiseAssertionError(
57+
f"12: info disagree on name{n!r}:{v1} !={v2}, "
58+
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
59+
f"ismethod={ismethod(v1)}."
60+
)
61+
ifv2!=v3:
62+
raiseAssertionError(
63+
f"23: info disagree on name{n!r}:{v2} !={v3}, "
64+
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
65+
f"ismethod={ismethod(v1)}."
66+
)
67+
68+
deftest_iiinfo(self):
69+
dt=np.int64
70+
fi1=np.iinfo(dt)
71+
fi2=xpn.iinfo(dt)
72+
fi3=xpo.iinfo(dt)
73+
dt1=fi1.dtype
74+
dt2=fi2.dtype
75+
dt3=fi3.dtype
76+
self.assertEqual(dt2,dt3)
77+
self.assertNotEqual(dt1.__class__,dt2.__class__)
78+
mi1=fi1.min
79+
mi2=fi2.min
80+
self.assertEqual(mi1,mi2)
81+
fornindir(fi1):
82+
ifn.startswith("__"):
83+
continue
84+
ifnin {"machar"}:
85+
continue
86+
v1=getattr(fi1,n)
87+
withself.subTest(att=n):
88+
v2=getattr(fi2,n)
89+
v3=getattr(fi3,n)
90+
ifisfunction(v1)orismethod(v1):
91+
try:
92+
v1=v1()
93+
exceptTypeError:
94+
continue
95+
v2=v2()
96+
v3=v3()
97+
ifv1!=v2:
98+
raiseAssertionError(
99+
f"12: info disagree on name{n!r}:{v1} !={v2}, "
100+
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
101+
f"ismethod={ismethod(v1)}."
102+
)
103+
ifv2!=v3:
104+
raiseAssertionError(
105+
f"23: info disagree on name{n!r}:{v2} !={v3}, "
106+
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
107+
f"ismethod={ismethod(v1)}."
108+
)
109+
110+
111+
if__name__=="__main__":
112+
unittest.main(verbosity=2)

‎_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
importnumpyasnp
33
fromonnx_array_api.ext_test_caseimportExtTestCase
44
fromonnx_array_api.array_apiimportonnx_numpyasxp
5-
fromonnx_array_api.npx.npx_numpy_tensorsimportEagerNumpyTensor
5+
fromonnx_array_api.npx.npx_numpy_tensorsimportEagerNumpyTensorasEagerTensor
66

77

88
classTestOnnxNumpy(ExtTestCase):
99
deftest_abs(self):
10-
c=EagerNumpyTensor(np.array([4,5],dtype=np.int64))
10+
c=EagerTensor(np.array([4,5],dtype=np.int64))
1111
mat=xp.zeros(c,dtype=xp.int64)
1212
matnp=mat.numpy()
1313
self.assertEqual(matnp.shape, (4,5))
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
importunittest
2+
importnumpyasnp
3+
fromonnx_array_api.ext_test_caseimportExtTestCase
4+
fromonnx_array_api.array_apiimportonnx_ortasxp
5+
fromonnx_array_api.ort.ort_tensorsimportEagerOrtTensorasEagerTensor
6+
7+
8+
classTestOnnxOrt(ExtTestCase):
9+
deftest_abs(self):
10+
c=EagerTensor(np.array([4,5],dtype=np.int64))
11+
mat=xp.zeros(c,dtype=xp.int64)
12+
matnp=mat.numpy()
13+
self.assertEqual(matnp.shape, (4,5))
14+
self.assertNotEmpty(matnp[0,0])
15+
a=xp.absolute(mat)
16+
self.assertEqualArray(np.absolute(mat.numpy()),a.numpy())
17+
18+
19+
if__name__=="__main__":
20+
unittest.main(verbosity=2)

‎_unittests/ut_ort/test_ort_tensor.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
importunittest
22
fromcontextlibimportredirect_stdout
33
fromioimportStringIO
4-
54
importnumpyasnp
65
fromonnx.defsimportonnx_opset_version
76
fromonnx.referenceimportReferenceEvaluator
87
fromonnxruntimeimportInferenceSession
9-
108
fromonnx_array_api.ext_test_caseimportExtTestCase
119
fromonnx_array_api.npximporteager_onnx,jit_onnx
1210
fromonnx_array_api.npx.npx_functionsimportabsoluteasabsolute_inline
1311
fromonnx_array_api.npx.npx_functionsimportcdistascdist_inline
1412
fromonnx_array_api.npx.npx_functions_testimportabsolute
15-
fromonnx_array_api.npx.npx_typesimportFloat32,Float64
13+
fromonnx_array_api.npx.npx_functionsimportcopyascopy_inline
14+
fromonnx_array_api.npx.npx_typesimportFloat32,Float64,DType
1615
fromonnx_array_api.npx.npx_varimportInput
1716
fromonnx_array_api.ort.ort_tensorsimportEagerOrtTensor,JitOrtTensor,OrtTensor
1817

@@ -193,6 +192,49 @@ def impl(xa, xb):
193192
iflen(pieces)>2:
194193
raiseAssertionError(f"Function is not using argument:\n{onx}")
195194

195+
deftest_astype(self):
196+
f=absolute_inline(copy_inline(Input("A")).astype(np.float32))
197+
onx=f.to_onnx(constraints={"A":Float64[None]})
198+
x=np.array([[-5,6]],dtype=np.float64)
199+
z=np.abs(x.astype(np.float32))
200+
ref=InferenceSession(
201+
onx.SerializeToString(),providers=["CPUExecutionProvider"]
202+
)
203+
got=ref.run(None, {"A":x})
204+
self.assertEqualArray(z,got[0])
205+
206+
deftest_astype0(self):
207+
f=absolute_inline(copy_inline(Input("A")).astype(np.float32))
208+
onx=f.to_onnx(constraints={"A":Float64[None]})
209+
x=np.array(-5,dtype=np.float64)
210+
z=np.abs(x.astype(np.float32))
211+
ref=InferenceSession(
212+
onx.SerializeToString(),providers=["CPUExecutionProvider"]
213+
)
214+
got=ref.run(None, {"A":x})
215+
self.assertEqualArray(z,got[0])
216+
217+
deftest_eager_ort_cast(self):
218+
defimpl(A):
219+
returnA.astype(DType("FLOAT"))
220+
221+
e=eager_onnx(impl)
222+
self.assertEqual(len(e.versions),0)
223+
224+
# Float64
225+
x=np.array([0,1,-2],dtype=np.float64)
226+
z=x.astype(np.float32)
227+
res=e(x)
228+
self.assertEqualArray(z,res)
229+
self.assertEqual(res.dtype,np.float32)
230+
231+
# again
232+
x=np.array(1,dtype=np.float64)
233+
z=x.astype(np.float32)
234+
res=e(x)
235+
self.assertEqualArray(z,res)
236+
self.assertEqual(res.dtype,np.float32)
237+
196238

197239
if__name__=="__main__":
198240
# TestNpx().test_eager_numpy()

‎azure-pipelines.yml

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ jobs:
110110
displayName:'Install tools'
111111
-script:pip install -r requirements.txt
112112
displayName:'Install Requirements'
113+
-script:pip install onnxruntime
114+
displayName:'Install onnxruntime'
113115
-script:python setup.py install
114116
displayName:'Install onnx_array_api'
115117
-script:|
@@ -129,8 +131,13 @@ jobs:
129131
-script:|
130132
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
131133
cd array-api-tests
132-
python -m pytest -x array_api_tests/test_creation_functions.py::test_zeros
133-
displayName: "test_creation_functions.py::test_zeros"
134+
python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt -v
135+
displayName: "numpy test_creation_functions.py"
136+
-script:|
137+
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort
138+
cd array-api-tests
139+
python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-ort-skips.txt -v
140+
displayName: "ort test_creation_functions.py"
134141
#- script: |
135142
# export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
136143
# cd array-api-tests
@@ -246,16 +253,8 @@ jobs:
246253
displayName:'export'
247254
-script:gcc --version
248255
displayName:'gcc version'
249-
-script:brew install llvm
250-
displayName:'install llvm'
251-
-script:brew install libomp
252-
displayName:'Install omp'
253-
-script:brew install p7zip
254-
displayName:'Install p7zip'
255256
-script:python -m pip install --upgrade pip setuptools wheel
256257
displayName:'Install tools'
257-
-script:brew install pybind11
258-
displayName:'Install pybind11'
259258
-script:pip install -r requirements.txt
260259
displayName:'Install Requirements'
261260
-script:pip install -r requirements-dev.txt

‎onnx_array_api/_helpers.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
importnumpyasnp
2+
fromtypingimportAny
3+
fromonnximporthelper,TensorProto
4+
5+
6+
defnp_dtype_to_tensor_dtype(dtype:Any):
7+
"""
8+
Improves :func:`onnx.helper.np_dtype_to_tensor_dtype`.
9+
"""
10+
try:
11+
dt=helper.np_dtype_to_tensor_dtype(dtype)
12+
exceptKeyError:
13+
ifdtype==np.float32:
14+
dt=TensorProto.FLOAT
15+
elifdtype==np.float64:
16+
dt=TensorProto.DOUBLE
17+
elifdtype==np.int64:
18+
dt=TensorProto.INT64
19+
elifdtype==np.int32:
20+
dt=TensorProto.INT32
21+
elifdtype==np.int16:
22+
dt=TensorProto.INT16
23+
elifdtype==np.int8:
24+
dt=TensorProto.INT8
25+
elifdtype==np.uint64:
26+
dt=TensorProto.UINT64
27+
elifdtype==np.uint32:
28+
dt=TensorProto.UINT32
29+
elifdtype==np.uint16:
30+
dt=TensorProto.UINT16
31+
elifdtype==np.uint8:
32+
dt=TensorProto.UINT8
33+
elifdtype==np.float16:
34+
dt=TensorProto.FLOAT16
35+
elifdtypein (bool,np.bool_):
36+
dt=TensorProto.BOOL
37+
elifdtypein (str,np.str_):
38+
dt=TensorProto.STRING
39+
elifdtypeisint:
40+
dt=TensorProto.INT64
41+
elifdtypeisfloat:
42+
dt=TensorProto.FLOAT64
43+
else:
44+
raiseKeyError(f"Unable to guess type for dtype={dtype}.")
45+
returndt

‎onnx_array_api/array_api/__init__.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,42 @@
1+
importnumpyasnp
12
fromonnximportTensorProto
3+
from .._helpersimportnp_dtype_to_tensor_dtype
24
from ..npx.npx_typesimportDType
35

46

7+
def_finfo(dtype):
8+
"""
9+
Similar to :class:`numpy.finfo`.
10+
"""
11+
dt=dtype.np_dtypeifisinstance(dtype,DType)elsedtype
12+
res=np.finfo(dt)
13+
d=res.__dict__.copy()
14+
d["dtype"]=DType(np_dtype_to_tensor_dtype(dt))
15+
nres=type("finfo", (res.__class__,),d)
16+
setattr(nres,"smallest_normal",res.smallest_normal)
17+
setattr(nres,"tiny",res.tiny)
18+
returnnres
19+
20+
21+
def_iinfo(dtype):
22+
"""
23+
Similar to :class:`numpy.finfo`.
24+
"""
25+
dt=dtype.np_dtypeifisinstance(dtype,DType)elsedtype
26+
res=np.iinfo(dt)
27+
d=res.__dict__.copy()
28+
d["dtype"]=DType(np_dtype_to_tensor_dtype(dt))
29+
nres=type("finfo", (res.__class__,),d)
30+
setattr(nres,"min",res.min)
31+
setattr(nres,"max",res.max)
32+
returnnres
33+
34+
535
def_finalize_array_api(module):
36+
"""
37+
Adds common attributes to Array API defined in this modules
38+
such as types.
39+
"""
640
module.float16=DType(TensorProto.FLOAT16)
741
module.float32=DType(TensorProto.FLOAT)
842
module.float64=DType(TensorProto.DOUBLE)
@@ -17,3 +51,5 @@ def _finalize_array_api(module):
1751
module.bfloat16=DType(TensorProto.BFLOAT16)
1852
setattr(module,"bool",DType(TensorProto.BOOL))
1953
setattr(module,"str",DType(TensorProto.STRING))
54+
setattr(module,"finfo",_finfo)
55+
setattr(module,"iinfo",_iinfo)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp