Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitc82f9f3

Browse files
authored
Supports function full for the Array API (#21)
* Supports function full for the Array API* improvments* fix keys by adding types* fix unit tests* ci
1 parentce37364 commitc82f9f3

17 files changed

+175
-44
lines changed

‎_unittests/onnx-numpy-skips.txt‎

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays
55
array_api_tests/test_creation_functions.py::test_empty
66
array_api_tests/test_creation_functions.py::test_empty_like
77
array_api_tests/test_creation_functions.py::test_eye
8-
array_api_tests/test_creation_functions.py::test_full
98
array_api_tests/test_creation_functions.py::test_full_like
109
array_api_tests/test_creation_functions.py::test_linspace
1110
array_api_tests/test_creation_functions.py::test_meshgrid

‎_unittests/test_array_api.sh‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2-
#pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_arrays || exit 1
2+
pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_scalars||exit 1
33
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
44
pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt||exit 1

‎_unittests/ut_array_api/test_array_apis.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class TestArraysApis(ExtTestCase):
1313
deftest_zeros_numpy_1(self):
1414
c=xpn.zeros(1)
1515
d=c.numpy()
16-
self.assertEqualArray(np.array([0],dtype=np.float32),d)
16+
self.assertEqualArray(np.array([0],dtype=np.float64),d)
1717

1818
deftest_zeros_ort_1(self):
1919
c=xpo.zeros(1)

‎_unittests/ut_array_api/test_onnx_numpy.py‎

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,40 @@ def test_zeros(self):
1919
a=xp.absolute(mat)
2020
self.assertEqualArray(np.absolute(mat.numpy()),a.numpy())
2121

22+
deftest_zeros_none(self):
23+
c=EagerTensor(np.array([4,5],dtype=np.int64))
24+
mat=xp.zeros(c)
25+
matnp=mat.numpy()
26+
self.assertEqual(matnp.shape, (4,5))
27+
self.assertNotEmpty(matnp[0,0])
28+
self.assertEqualArray(matnp,np.zeros((4,5)))
29+
30+
deftest_ones_none(self):
31+
c=EagerTensor(np.array([4,5],dtype=np.int64))
32+
mat=xp.ones(c)
33+
matnp=mat.numpy()
34+
self.assertEqual(matnp.shape, (4,5))
35+
self.assertNotEmpty(matnp[0,0])
36+
self.assertEqualArray(matnp,np.ones((4,5)))
37+
38+
deftest_full(self):
39+
c=EagerTensor(np.array([4,5],dtype=np.int64))
40+
mat=xp.full(c,fill_value=5,dtype=xp.int64)
41+
matnp=mat.numpy()
42+
self.assertEqual(matnp.shape, (4,5))
43+
self.assertNotEmpty(matnp[0,0])
44+
a=xp.absolute(mat)
45+
self.assertEqualArray(np.absolute(mat.numpy()),a.numpy())
46+
47+
deftest_full_bool(self):
48+
c=EagerTensor(np.array([4,5],dtype=np.int64))
49+
mat=xp.full(c,fill_value=False)
50+
matnp=mat.numpy()
51+
self.assertEqual(matnp.shape, (4,5))
52+
self.assertNotEmpty(matnp[0,0])
53+
self.assertEqualArray(matnp,np.full((4,5),False))
54+
2255

2356
if__name__=="__main__":
57+
TestOnnxNumpy().test_zeros_none()
2458
unittest.main(verbosity=2)

‎_unittests/ut_npx/test_npx.py‎

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,8 +710,8 @@ def impl(
710710
keys=list(sorted(f.onxs))
711711
self.assertIsInstance(f.onxs[keys[0]],ModelProto)
712712
k=keys[-1]
713-
self.assertEqual(len(k),3)
714-
self.assertEqual(k[1:], ("axis",0))
713+
self.assertEqual(len(k),4)
714+
self.assertEqual(k[1:], ("axis",int,0))
715715

716716
deftest_numpy_topk(self):
717717
f=topk(Input("X"),Input("K"))
@@ -2416,6 +2416,7 @@ def compute_labels(X, centers, use_sqrt=False):
24162416
(DType(TensorProto.DOUBLE),2),
24172417
(DType(TensorProto.DOUBLE),2),
24182418
"use_sqrt",
2419+
bool,
24192420
True,
24202421
)
24212422
self.assertEqual(f.available_versions, [key])

‎azure-pipelines.yml‎

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
vmImage:'ubuntu-latest'
4949
strategy:
5050
matrix:
51-
Python310-Linux:
51+
Python311-Linux:
5252
python.version:'3.11'
5353
maxParallel:3
5454

@@ -96,7 +96,7 @@ jobs:
9696
strategy:
9797
matrix:
9898
Python310-Linux:
99-
python.version:'3.11'
99+
python.version:'3.10'
100100
maxParallel:3
101101

102102
steps:
@@ -149,7 +149,7 @@ jobs:
149149
vmImage:'ubuntu-latest'
150150
strategy:
151151
matrix:
152-
Python310-Linux:
152+
Python311-Linux:
153153
python.version:'3.11'
154154
maxParallel:3
155155

@@ -202,7 +202,7 @@ jobs:
202202
vmImage:'windows-latest'
203203
strategy:
204204
matrix:
205-
Python310-Windows:
205+
Python311-Windows:
206206
python.version:'3.11'
207207
maxParallel:3
208208

@@ -235,7 +235,7 @@ jobs:
235235
vmImage:'macOS-latest'
236236
strategy:
237237
matrix:
238-
Python310-Mac:
238+
Python311-Mac:
239239
python.version:'3.11'
240240
maxParallel:3
241241

‎onnx_array_api/_helpers.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def np_dtype_to_tensor_dtype(dtype: Any):
3939
elifdtypeisint:
4040
dt=TensorProto.INT64
4141
elifdtypeisfloat:
42-
dt=TensorProto.FLOAT64
42+
dt=TensorProto.DOUBLE
4343
else:
4444
raiseKeyError(f"Unable to guess type for dtype={dtype}.")
4545
returndt

‎onnx_array_api/array_api/_onnx_common.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def template_asarray(
4444
exceptOverflowError:
4545
v=TEagerTensor(np.asarray(a,dtype=np.uint64))
4646
elifisinstance(a,float):
47-
v=TEagerTensor(np.array(a,dtype=np.float32))
47+
v=TEagerTensor(np.array(a,dtype=np.float64))
4848
elifisinstance(a,bool):
4949
v=TEagerTensor(np.array(a,dtype=np.bool_))
5050
elifisinstance(a,str):

‎onnx_array_api/array_api/onnx_numpy.py‎

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44
fromtypingimportAny,Optional
55
importnumpyasnp
6-
fromonnximportTensorProto
76
from ..npx.npx_functionsimport (
87
all,
98
abs,
@@ -16,10 +15,11 @@
1615
reshape,
1716
take,
1817
)
18+
from ..npx.npx_functionsimportfullasgeneric_full
1919
from ..npx.npx_functionsimportonesasgeneric_ones
2020
from ..npx.npx_functionsimportzerosasgeneric_zeros
2121
from ..npx.npx_numpy_tensorsimportEagerNumpyTensor
22-
from ..npx.npx_typesimportDType,ElemType,TensorType,OptParType
22+
from ..npx.npx_typesimportDType,ElemType,TensorType,OptParType,ParType,Scalar
2323
from ._onnx_commonimporttemplate_asarray
2424
from .import_finalize_array_api
2525

@@ -31,6 +31,7 @@
3131
"astype",
3232
"empty",
3333
"equal",
34+
"full",
3435
"isdtype",
3536
"isfinite",
3637
"isnan",
@@ -58,7 +59,7 @@ def asarray(
5859

5960
defones(
6061
shape:TensorType[ElemType.int64,"I", (None,)],
61-
dtype:OptParType[DType]=DType(TensorProto.FLOAT),
62+
dtype:OptParType[DType]=None,
6263
order:OptParType[str]="C",
6364
)->TensorType[ElemType.numerics,"T"]:
6465
ifisinstance(shape,tuple):
@@ -76,7 +77,7 @@ def ones(
7677

7778
defempty(
7879
shape:TensorType[ElemType.int64,"I", (None,)],
79-
dtype:OptParType[DType]=DType(TensorProto.FLOAT),
80+
dtype:OptParType[DType]=None,
8081
order:OptParType[str]="C",
8182
)->TensorType[ElemType.numerics,"T"]:
8283
raiseRuntimeError(
@@ -87,7 +88,7 @@ def empty(
8788

8889
defzeros(
8990
shape:TensorType[ElemType.int64,"I", (None,)],
90-
dtype:OptParType[DType]=DType(TensorProto.FLOAT),
91+
dtype:OptParType[DType]=None,
9192
order:OptParType[str]="C",
9293
)->TensorType[ElemType.numerics,"T"]:
9394
ifisinstance(shape,tuple):
@@ -103,6 +104,32 @@ def zeros(
103104
returngeneric_zeros(shape,dtype=dtype,order=order)
104105

105106

107+
deffull(
108+
shape:TensorType[ElemType.int64,"I", (None,)],
109+
fill_value:ParType[Scalar]=None,
110+
dtype:OptParType[DType]=None,
111+
order:OptParType[str]="C",
112+
)->TensorType[ElemType.numerics,"T"]:
113+
iffill_valueisNone:
114+
raiseTypeError("fill_value cannot be None")
115+
value=fill_value
116+
ifisinstance(shape,tuple):
117+
returngeneric_full(
118+
EagerNumpyTensor(np.array(shape,dtype=np.int64)),
119+
fill_value=value,
120+
dtype=dtype,
121+
order=order,
122+
)
123+
ifisinstance(shape,int):
124+
returngeneric_full(
125+
EagerNumpyTensor(np.array([shape],dtype=np.int64)),
126+
fill_value=value,
127+
dtype=dtype,
128+
order=order,
129+
)
130+
returngeneric_full(shape,fill_value=value,dtype=dtype,order=order)
131+
132+
106133
def_finalize():
107134
"""
108135
Adds common attributes to Array API defined in this modules

‎onnx_array_api/npx/npx_core_api.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def wrapper(*inputs, **kwargs):
169169
new_inputs.append(i)
170170
elifisinstance(i, (int,float)):
171171
new_inputs.append(
172-
np.array([i],dtype=np.int64ifisinstance(i,int)elsenp.float32)
172+
np.array([i],dtype=np.int64ifisinstance(i,int)elsenp.float64)
173173
)
174174
elifisinstance(i,str):
175175
new_inputs.append(Input(i))

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp