Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit6bea970

Browse files
authored
Add function Eye to the Array API (#29)
* Add function Eye to the Array API* remove eye* improve* fix overflow
1 parent35cb298 commit6bea970

File tree

10 files changed

+157
-10
lines changed

10 files changed

+157
-10
lines changed

‎_unittests/onnx-numpy-skips.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
# API failures
22
# see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt
3-
array_api_tests/test_creation_functions.py::test_asarray_scalars
4-
array_api_tests/test_creation_functions.py::test_arange
3+
# uses __setitem__
54
array_api_tests/test_creation_functions.py::test_asarray_arrays
65
array_api_tests/test_creation_functions.py::test_empty
76
array_api_tests/test_creation_functions.py::test_empty_like
8-
array_api_tests/test_creation_functions.py::test_eye
97
array_api_tests/test_creation_functions.py::test_linspace
108
array_api_tests/test_creation_functions.py::test_meshgrid

‎_unittests/test_array_api.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2-
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros_like||exit 1
2+
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_eye||exit 1
33
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
44
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt||exit 1

‎_unittests/ut_array_api/test_hypothesis_array_api.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def sh(x):
3939

4040
classTestHypothesisArraysApis(ExtTestCase):
4141
MAX_ARRAY_SIZE=10000
42+
SQRT_MAX_ARRAY_SIZE=int(10000**0.5)
4243
VERSION="2021.12"
4344

4445
@classmethod
@@ -138,9 +139,80 @@ def fctonx(x, kw):
138139
fctonx()
139140
self.assertEqual(len(args_onxp),len(args_np))
140141

142+
deftest_square_sizes_strategies(self):
143+
dtypes=dict(
144+
integer_dtypes=self.xps.integer_dtypes(),
145+
uinteger_dtypes=self.xps.unsigned_integer_dtypes(),
146+
floating_dtypes=self.xps.floating_dtypes(),
147+
numeric_dtypes=self.xps.numeric_dtypes(),
148+
boolean_dtypes=self.xps.boolean_dtypes(),
149+
scalar_dtypes=self.xps.scalar_dtypes(),
150+
)
151+
152+
dtypes_onnx=dict(
153+
integer_dtypes=self.onxps.integer_dtypes(),
154+
uinteger_dtypes=self.onxps.unsigned_integer_dtypes(),
155+
floating_dtypes=self.onxps.floating_dtypes(),
156+
numeric_dtypes=self.onxps.numeric_dtypes(),
157+
boolean_dtypes=self.onxps.boolean_dtypes(),
158+
scalar_dtypes=self.onxps.scalar_dtypes(),
159+
)
160+
161+
fork,vnpindtypes.items():
162+
vonxp=dtypes_onnx[k]
163+
anp=self.xps.arrays(dtype=vnp,shape=shapes(self.xps))
164+
aonxp=self.onxps.arrays(dtype=vonxp,shape=shapes(self.onxps))
165+
self.assertNotEmpty(anp)
166+
self.assertNotEmpty(aonxp)
167+
168+
args_np= []
169+
170+
kws=array_api_kwargs(k=strategies.integers(),dtype=self.xps.numeric_dtypes())
171+
sqrt_sizes=strategies.integers(0,self.SQRT_MAX_ARRAY_SIZE)
172+
ncs=strategies.none()|sqrt_sizes
173+
174+
@given(n_rows=sqrt_sizes,n_cols=ncs,kw=kws)
175+
deffctnp(n_rows,n_cols,kw):
176+
base=np.asarray(0)
177+
e=np.eye(n_rows,n_cols)
178+
self.assertNotEmpty(e.dtype)
179+
self.assertIsInstance(e,base.__class__)
180+
e=np.eye(n_rows,n_cols,**kw)
181+
self.assertNotEmpty(e.dtype)
182+
self.assertIsInstance(e,base.__class__)
183+
args_np.append((n_rows,n_cols,kw))
184+
185+
fctnp()
186+
self.assertEqual(len(args_np),100)
187+
188+
args_onxp= []
189+
190+
kws=array_api_kwargs(
191+
k=strategies.integers(),dtype=self.onxps.numeric_dtypes()
192+
)
193+
sqrt_sizes=strategies.integers(0,self.SQRT_MAX_ARRAY_SIZE)
194+
ncs=strategies.none()|sqrt_sizes
195+
196+
@given(n_rows=sqrt_sizes,n_cols=ncs,kw=kws)
197+
deffctonx(n_rows,n_cols,kw):
198+
base=onxp.asarray(0)
199+
e=onxp.eye(n_rows,n_cols)
200+
self.assertIsInstance(e,base.__class__)
201+
self.assertNotEmpty(e.dtype)
202+
e=onxp.eye(n_rows,n_cols,**kw)
203+
self.assertNotEmpty(e.dtype)
204+
self.assertIsInstance(e,base.__class__)
205+
args_onxp.append((n_rows,n_cols,kw))
206+
207+
fctonx()
208+
self.assertEqual(len(args_onxp),len(args_np))
209+
141210

142211
if__name__=="__main__":
143212
# cl = TestHypothesisArraysApis()
144213
# cl.setUpClass()
145214
# cl.test_scalar_strategies()
215+
# import logging
216+
217+
# logging.basicConfig(level=logging.DEBUG)
146218
unittest.main(verbosity=2)

‎_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,28 @@ def test_as_array(self):
142142
self.assertEqual(r.dtype,DType(TensorProto.UINT64))
143143
self.assertEqual(r.numpy(),9223372036854775809)
144144

145+
deftest_eye(self):
146+
nr,nc=xp.asarray(4),xp.asarray(4)
147+
expected=np.eye(nr.numpy(),nc.numpy())
148+
got=xp.eye(nr,nc)
149+
self.assertEqualArray(expected,got.numpy())
150+
151+
deftest_eye_nosquare(self):
152+
nr,nc=xp.asarray(4),xp.asarray(5)
153+
expected=np.eye(nr.numpy(),nc.numpy())
154+
got=xp.eye(nr,nc)
155+
self.assertEqualArray(expected,got.numpy())
156+
157+
deftest_eye_k(self):
158+
nr=xp.asarray(4)
159+
expected=np.eye(nr.numpy(),k=1)
160+
got=xp.eye(nr,k=1)
161+
self.assertEqualArray(expected,got.numpy())
162+
145163

146164
if__name__=="__main__":
147165
# import logging
148166

149167
# logging.basicConfig(level=logging.DEBUG)
150-
#TestOnnxNumpy().test_as_array()
168+
TestOnnxNumpy().test_eye()
151169
unittest.main(verbosity=2)

‎onnx_array_api/array_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"astype",
1818
"empty",
1919
"equal",
20+
"eye",
2021
"full",
2122
"full_like",
2223
"isdtype",

‎onnx_array_api/array_api/_onnx_common.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
fromtypingimportAny,Optional
22
importwarnings
33
importnumpyasnp
4+
fromonnximportTensorProto
45

56
withwarnings.catch_warnings():
67
warnings.simplefilter("ignore")
@@ -19,6 +20,8 @@
1920
from ..npx.npx_functionsimport (
2021
absasgeneric_abs,
2122
arangeasgeneric_arange,
23+
copyascopy_inline,
24+
eyeasgeneric_eye,
2225
fullasgeneric_full,
2326
full_likeasgeneric_full_like,
2427
onesasgeneric_ones,
@@ -185,6 +188,24 @@ def full(
185188
returngeneric_full(shape,fill_value=value,dtype=dtype,order=order)
186189

187190

191+
defeye(
192+
TEagerTensor:type,
193+
n_rows:TensorType[ElemType.int64,"I"],
194+
n_cols:OptTensorType[ElemType.int64,"I"]=None,
195+
/,
196+
*,
197+
k:ParType[int]=0,
198+
dtype:ParType[DType]=DType(TensorProto.DOUBLE),
199+
):
200+
ifisinstance(n_rows,int):
201+
n_rows=TEagerTensor(np.array(n_rows,dtype=np.int64))
202+
ifn_colsisNone:
203+
n_cols=n_rows
204+
elifisinstance(n_cols,int):
205+
n_cols=TEagerTensor(np.array(n_cols,dtype=np.int64))
206+
returngeneric_eye(n_rows,n_cols,k=k,dtype=dtype)
207+
208+
188209
deffull_like(
189210
TEagerTensor:type,
190211
x:TensorType[ElemType.allowed,"T"],

‎onnx_array_api/npx/npx_functions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,30 @@ def expit(
473473
returnvar(x,op="Sigmoid")
474474

475475

476+
@npxapi_inline
477+
defeye(
478+
n_rows:TensorType[ElemType.int64,"I"],
479+
n_cols:TensorType[ElemType.int64,"I"],
480+
/,
481+
*,
482+
k:ParType[int]=0,
483+
dtype:ParType[DType]=DType(TensorProto.DOUBLE),
484+
):
485+
"See :func:`numpy.eye`."
486+
shape=cst(np.array([-1],dtype=np.int64))
487+
shape=var(
488+
var(n_rows,shape,op="Reshape"),
489+
var(n_cols,shape,op="Reshape"),
490+
axis=0,
491+
op="Concat",
492+
)
493+
zero=zeros(shape,dtype=dtype)
494+
res=var(zero,k=k,op="EyeLike")
495+
ifdtypeisnotNone:
496+
returnvar(res,to=dtype.code,op="Cast")
497+
returnres
498+
499+
476500
@npxapi_inline
477501
deffull(
478502
shape:TensorType[ElemType.int64,"I", (None,)],

‎onnx_array_api/npx/npx_graph_builder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,11 @@ def make_node(
230230
new_kwargs[k]=v.value
231231
elifisinstance(v,DType):
232232
new_kwargs[k]=v.code
233+
elifisinstance(v,int):
234+
try:
235+
new_kwargs[k]=int(np.array(v,dtype=np.int64))
236+
exceptOverflowError:
237+
new_kwargs[k]=int(np.iinfo(np.int64).max)
233238
else:
234239
new_kwargs[k]=v
235240

@@ -246,6 +251,11 @@ def make_node(
246251
f"Unable to create node{op!r}, with inputs={inputs}, "
247252
f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}."
248253
)frome
254+
exceptValueErrorase:
255+
raiseValueError(
256+
f"Unable to create node{op!r}, with inputs={inputs}, "
257+
f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}."
258+
)frome
249259
forpinprotos:
250260
node.attribute.append(p)
251261
ifattribute_protosisnotNone:

‎onnx_array_api/npx/npx_jit_eager.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,11 +510,18 @@ def jit_call(self, *values, **kwargs):
510510
from ..plotting.text_plotimportonnx_simple_text_plot
511511

512512
text=onnx_simple_text_plot(self.onxs[key])
513+
514+
defcatch_len(x):
515+
try:
516+
returnlen(x)
517+
exceptTypeError:
518+
return0
519+
513520
raiseRuntimeError(
514521
f"Unable to run function for key={key!r}, "
515522
f"types={[type(x)forxinvalues]}, "
516523
f"dtypes={[getattr(x,'dtype',type(x))forxinvalues]}, "
517-
f"shapes={[getattr(x,'shape',len(x))forxinvalues]}, "
524+
f"shapes={[getattr(x,'shape',catch_len(x))forxinvalues]}, "
518525
f"kwargs={kwargs}, "
519526
f"self.input_to_kwargs_={self.input_to_kwargs_}, "
520527
f"f={self.f} from module{self.f.__module__!r} "

‎onnx_array_api/reference/evaluator.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77
from .ops.op_cast_likeimportCastLike_15,CastLike_19
88
from .ops.op_constant_of_shapeimportConstantOfShape
99

10-
importonnx
11-
12-
print(onnx.__file__)
13-
1410

1511
logger=getLogger("onnx-array-api-eval")
1612

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp