Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitc82f9f3

Browse files
authored
Supports function full for the Array API (#21)
* Supports function full for the Array API* improvments* fix keys by adding types* fix unit tests* ci
1 parentce37364 commitc82f9f3

17 files changed

+175
-44
lines changed

‎_unittests/onnx-numpy-skips.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays
55
array_api_tests/test_creation_functions.py::test_empty
66
array_api_tests/test_creation_functions.py::test_empty_like
77
array_api_tests/test_creation_functions.py::test_eye
8-
array_api_tests/test_creation_functions.py::test_full
98
array_api_tests/test_creation_functions.py::test_full_like
109
array_api_tests/test_creation_functions.py::test_linspace
1110
array_api_tests/test_creation_functions.py::test_meshgrid

‎_unittests/test_array_api.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2-
#pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_arrays || exit 1
2+
pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_scalars||exit 1
33
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
44
pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt||exit 1

‎_unittests/ut_array_api/test_array_apis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class TestArraysApis(ExtTestCase):
1313
deftest_zeros_numpy_1(self):
1414
c=xpn.zeros(1)
1515
d=c.numpy()
16-
self.assertEqualArray(np.array([0],dtype=np.float32),d)
16+
self.assertEqualArray(np.array([0],dtype=np.float64),d)
1717

1818
deftest_zeros_ort_1(self):
1919
c=xpo.zeros(1)

‎_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,40 @@ def test_zeros(self):
1919
a=xp.absolute(mat)
2020
self.assertEqualArray(np.absolute(mat.numpy()),a.numpy())
2121

22+
deftest_zeros_none(self):
23+
c=EagerTensor(np.array([4,5],dtype=np.int64))
24+
mat=xp.zeros(c)
25+
matnp=mat.numpy()
26+
self.assertEqual(matnp.shape, (4,5))
27+
self.assertNotEmpty(matnp[0,0])
28+
self.assertEqualArray(matnp,np.zeros((4,5)))
29+
30+
deftest_ones_none(self):
31+
c=EagerTensor(np.array([4,5],dtype=np.int64))
32+
mat=xp.ones(c)
33+
matnp=mat.numpy()
34+
self.assertEqual(matnp.shape, (4,5))
35+
self.assertNotEmpty(matnp[0,0])
36+
self.assertEqualArray(matnp,np.ones((4,5)))
37+
38+
deftest_full(self):
39+
c=EagerTensor(np.array([4,5],dtype=np.int64))
40+
mat=xp.full(c,fill_value=5,dtype=xp.int64)
41+
matnp=mat.numpy()
42+
self.assertEqual(matnp.shape, (4,5))
43+
self.assertNotEmpty(matnp[0,0])
44+
a=xp.absolute(mat)
45+
self.assertEqualArray(np.absolute(mat.numpy()),a.numpy())
46+
47+
deftest_full_bool(self):
48+
c=EagerTensor(np.array([4,5],dtype=np.int64))
49+
mat=xp.full(c,fill_value=False)
50+
matnp=mat.numpy()
51+
self.assertEqual(matnp.shape, (4,5))
52+
self.assertNotEmpty(matnp[0,0])
53+
self.assertEqualArray(matnp,np.full((4,5),False))
54+
2255

2356
if__name__=="__main__":
57+
TestOnnxNumpy().test_zeros_none()
2458
unittest.main(verbosity=2)

‎_unittests/ut_npx/test_npx.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,8 +710,8 @@ def impl(
710710
keys=list(sorted(f.onxs))
711711
self.assertIsInstance(f.onxs[keys[0]],ModelProto)
712712
k=keys[-1]
713-
self.assertEqual(len(k),3)
714-
self.assertEqual(k[1:], ("axis",0))
713+
self.assertEqual(len(k),4)
714+
self.assertEqual(k[1:], ("axis",int,0))
715715

716716
deftest_numpy_topk(self):
717717
f=topk(Input("X"),Input("K"))
@@ -2416,6 +2416,7 @@ def compute_labels(X, centers, use_sqrt=False):
24162416
(DType(TensorProto.DOUBLE),2),
24172417
(DType(TensorProto.DOUBLE),2),
24182418
"use_sqrt",
2419+
bool,
24192420
True,
24202421
)
24212422
self.assertEqual(f.available_versions, [key])

‎azure-pipelines.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
vmImage:'ubuntu-latest'
4949
strategy:
5050
matrix:
51-
Python310-Linux:
51+
Python311-Linux:
5252
python.version:'3.11'
5353
maxParallel:3
5454

@@ -96,7 +96,7 @@ jobs:
9696
strategy:
9797
matrix:
9898
Python310-Linux:
99-
python.version:'3.11'
99+
python.version:'3.10'
100100
maxParallel:3
101101

102102
steps:
@@ -149,7 +149,7 @@ jobs:
149149
vmImage:'ubuntu-latest'
150150
strategy:
151151
matrix:
152-
Python310-Linux:
152+
Python311-Linux:
153153
python.version:'3.11'
154154
maxParallel:3
155155

@@ -202,7 +202,7 @@ jobs:
202202
vmImage:'windows-latest'
203203
strategy:
204204
matrix:
205-
Python310-Windows:
205+
Python311-Windows:
206206
python.version:'3.11'
207207
maxParallel:3
208208

@@ -235,7 +235,7 @@ jobs:
235235
vmImage:'macOS-latest'
236236
strategy:
237237
matrix:
238-
Python310-Mac:
238+
Python311-Mac:
239239
python.version:'3.11'
240240
maxParallel:3
241241

‎onnx_array_api/_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def np_dtype_to_tensor_dtype(dtype: Any):
3939
elifdtypeisint:
4040
dt=TensorProto.INT64
4141
elifdtypeisfloat:
42-
dt=TensorProto.FLOAT64
42+
dt=TensorProto.DOUBLE
4343
else:
4444
raiseKeyError(f"Unable to guess type for dtype={dtype}.")
4545
returndt

‎onnx_array_api/array_api/_onnx_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def template_asarray(
4444
exceptOverflowError:
4545
v=TEagerTensor(np.asarray(a,dtype=np.uint64))
4646
elifisinstance(a,float):
47-
v=TEagerTensor(np.array(a,dtype=np.float32))
47+
v=TEagerTensor(np.array(a,dtype=np.float64))
4848
elifisinstance(a,bool):
4949
v=TEagerTensor(np.array(a,dtype=np.bool_))
5050
elifisinstance(a,str):

‎onnx_array_api/array_api/onnx_numpy.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44
fromtypingimportAny,Optional
55
importnumpyasnp
6-
fromonnximportTensorProto
76
from ..npx.npx_functionsimport (
87
all,
98
abs,
@@ -16,10 +15,11 @@
1615
reshape,
1716
take,
1817
)
18+
from ..npx.npx_functionsimportfullasgeneric_full
1919
from ..npx.npx_functionsimportonesasgeneric_ones
2020
from ..npx.npx_functionsimportzerosasgeneric_zeros
2121
from ..npx.npx_numpy_tensorsimportEagerNumpyTensor
22-
from ..npx.npx_typesimportDType,ElemType,TensorType,OptParType
22+
from ..npx.npx_typesimportDType,ElemType,TensorType,OptParType,ParType,Scalar
2323
from ._onnx_commonimporttemplate_asarray
2424
from .import_finalize_array_api
2525

@@ -31,6 +31,7 @@
3131
"astype",
3232
"empty",
3333
"equal",
34+
"full",
3435
"isdtype",
3536
"isfinite",
3637
"isnan",
@@ -58,7 +59,7 @@ def asarray(
5859

5960
defones(
6061
shape:TensorType[ElemType.int64,"I", (None,)],
61-
dtype:OptParType[DType]=DType(TensorProto.FLOAT),
62+
dtype:OptParType[DType]=None,
6263
order:OptParType[str]="C",
6364
)->TensorType[ElemType.numerics,"T"]:
6465
ifisinstance(shape,tuple):
@@ -76,7 +77,7 @@ def ones(
7677

7778
defempty(
7879
shape:TensorType[ElemType.int64,"I", (None,)],
79-
dtype:OptParType[DType]=DType(TensorProto.FLOAT),
80+
dtype:OptParType[DType]=None,
8081
order:OptParType[str]="C",
8182
)->TensorType[ElemType.numerics,"T"]:
8283
raiseRuntimeError(
@@ -87,7 +88,7 @@ def empty(
8788

8889
defzeros(
8990
shape:TensorType[ElemType.int64,"I", (None,)],
90-
dtype:OptParType[DType]=DType(TensorProto.FLOAT),
91+
dtype:OptParType[DType]=None,
9192
order:OptParType[str]="C",
9293
)->TensorType[ElemType.numerics,"T"]:
9394
ifisinstance(shape,tuple):
@@ -103,6 +104,32 @@ def zeros(
103104
returngeneric_zeros(shape,dtype=dtype,order=order)
104105

105106

107+
deffull(
108+
shape:TensorType[ElemType.int64,"I", (None,)],
109+
fill_value:ParType[Scalar]=None,
110+
dtype:OptParType[DType]=None,
111+
order:OptParType[str]="C",
112+
)->TensorType[ElemType.numerics,"T"]:
113+
iffill_valueisNone:
114+
raiseTypeError("fill_value cannot be None")
115+
value=fill_value
116+
ifisinstance(shape,tuple):
117+
returngeneric_full(
118+
EagerNumpyTensor(np.array(shape,dtype=np.int64)),
119+
fill_value=value,
120+
dtype=dtype,
121+
order=order,
122+
)
123+
ifisinstance(shape,int):
124+
returngeneric_full(
125+
EagerNumpyTensor(np.array([shape],dtype=np.int64)),
126+
fill_value=value,
127+
dtype=dtype,
128+
order=order,
129+
)
130+
returngeneric_full(shape,fill_value=value,dtype=dtype,order=order)
131+
132+
106133
def_finalize():
107134
"""
108135
Adds common attributes to Array API defined in this modules

‎onnx_array_api/npx/npx_core_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def wrapper(*inputs, **kwargs):
169169
new_inputs.append(i)
170170
elifisinstance(i, (int,float)):
171171
new_inputs.append(
172-
np.array([i],dtype=np.int64ifisinstance(i,int)elsenp.float32)
172+
np.array([i],dtype=np.int64ifisinstance(i,int)elsenp.float64)
173173
)
174174
elifisinstance(i,str):
175175
new_inputs.append(Input(i))

‎onnx_array_api/npx/npx_functions.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
SequenceType,
1616
TensorType,
1717
TupleType,
18+
Scalar,
1819
)
1920
from .npx_varimportVar
2021

2122

2223
def_cstv(x):
2324
ifisinstance(x,Var):
2425
returnx
25-
ifisinstance(x, (int,float,np.ndarray)):
26+
ifisinstance(x, (int,float,bool,np.ndarray)):
2627
returncst(x)
2728
raiseTypeError(f"Unexpected constant type{type(x)}.")
2829

@@ -376,6 +377,42 @@ def expit(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics
376377
returnvar(x,op="Sigmoid")
377378

378379

380+
@npxapi_inline
381+
deffull(
382+
shape:TensorType[ElemType.int64,"I", (None,)],
383+
dtype:OptParType[DType]=None,
384+
fill_value:ParType[Scalar]=None,
385+
order:OptParType[str]="C",
386+
)->TensorType[ElemType.numerics,"T"]:
387+
"""
388+
Implements :func:`numpy.full`.
389+
"""
390+
iforder!="C":
391+
raiseRuntimeError(f"order={order!r} != 'C' not supported.")
392+
iffill_valueisNone:
393+
raiseTypeError("fill_value cannot be None.")
394+
ifdtypeisNone:
395+
ifisinstance(fill_value,bool):
396+
dtype=DType(TensorProto.BOOL)
397+
elifisinstance(fill_value,int):
398+
dtype=DType(TensorProto.INT64)
399+
elifisinstance(fill_value,float):
400+
dtype=DType(TensorProto.DOUBLE)
401+
else:
402+
raiseTypeError(
403+
f"Unexpected type{type(fill_value)} for fill_value={fill_value!r}."
404+
)
405+
ifisinstance(fill_value, (float,int,bool)):
406+
value=make_tensor(
407+
name="cst",data_type=dtype.code,dims=[1],vals=[fill_value]
408+
)
409+
else:
410+
raiseNotImplementedError(
411+
f"Unexpected type ({type(fill_value)} for fill_value={fill_value!r}."
412+
)
413+
returnvar(shape,value=value,op="ConstantOfShape")
414+
415+
379416
@npxapi_inline
380417
deffloor(x:TensorType[ElemType.numerics,"T"])->TensorType[ElemType.numerics,"T"]:
381418
"See :func:`numpy.floor`."
@@ -464,7 +501,7 @@ def matmul(
464501
@npxapi_inline
465502
defones(
466503
shape:TensorType[ElemType.int64,"I", (None,)],
467-
dtype:OptParType[DType]=DType(TensorProto.FLOAT),
504+
dtype:OptParType[DType]=None,
468505
order:OptParType[str]="C",
469506
)->TensorType[ElemType.numerics,"T"]:
470507
"""
@@ -473,7 +510,7 @@ def ones(
473510
iforder!="C":
474511
raiseRuntimeError(f"order={order!r} != 'C' not supported.")
475512
ifdtypeisNone:
476-
dtype=DType(TensorProto.FLOAT)
513+
dtype=DType(TensorProto.DOUBLE)
477514
returnvar(
478515
shape,
479516
value=make_tensor(name="one",data_type=dtype.code,dims=[1],vals=[1]),
@@ -674,7 +711,7 @@ def where(
674711
@npxapi_inline
675712
defzeros(
676713
shape:TensorType[ElemType.int64,"I", (None,)],
677-
dtype:OptParType[DType]=DType(TensorProto.FLOAT),
714+
dtype:OptParType[DType]=None,
678715
order:OptParType[str]="C",
679716
)->TensorType[ElemType.numerics,"T"]:
680717
"""
@@ -683,7 +720,7 @@ def zeros(
683720
iforder!="C":
684721
raiseRuntimeError(f"order={order!r} != 'C' not supported.")
685722
ifdtypeisNone:
686-
dtype=DType(TensorProto.FLOAT)
723+
dtype=DType(TensorProto.DOUBLE)
687724
returnvar(
688725
shape,
689726
value=make_tensor(name="zero",data_type=dtype.code,dims=[1],vals=[0]),

‎onnx_array_api/npx/npx_graph_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,7 @@ def to_onnx(
784784
node_inputs.append(input_name)
785785
continue
786786

787-
ifisinstance(i, (int,float)):
787+
ifisinstance(i, (int,float,bool)):
788788
ni=np.array(i)
789789
c=Cst(ni)
790790
input_name=self._unique(var._prefix)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp