Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit1be44a7

Browse files
sdpythonxadupre
andauthored
Enables linspace (#30)
* Enables test_asarray_scalars* Add support for array api linspace* lint* improves consistency for linspace* fix linspace* disable asarrays_arrays* fix strategies* aapi---------Co-authored-by: Xavier Dupre <xadupre@microsoft.com>
1 parent9b0b5d6 commit1be44a7

File tree

12 files changed

+409
-19
lines changed

12 files changed

+409
-19
lines changed

‎_unittests/onnx-numpy-skips.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# API failures
22
# see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt
33
# uses __setitem__
4-
array_api_tests/test_creation_functions.py::test_asarray_arrays
4+
#array_api_tests/test_creation_functions.py::test_asarray_arrays
55
array_api_tests/test_creation_functions.py::test_empty
66
array_api_tests/test_creation_functions.py::test_empty_like
7-
array_api_tests/test_creation_functions.py::test_linspace
7+
#array_api_tests/test_creation_functions.py::test_linspace
88
array_api_tests/test_creation_functions.py::test_meshgrid

‎_unittests/test_array_api.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2-
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_eye||exit 1
2+
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_full_like||exit 1
33
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
44
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt||exit 1

‎_unittests/ut_array_api/test_hypothesis_array_api.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
importnumpyasnp
66
fromoperatorimportmul
77
fromhypothesisimportgiven
8-
fromonnx_array_api.ext_test_caseimportExtTestCase
8+
fromonnx_array_api.ext_test_caseimportExtTestCase,ignore_warnings
99
fromonnx_array_api.array_apiimportonnx_numpyasonxp
1010
fromhypothesisimportstrategies
1111
fromhypothesis.extraimportarray_api
@@ -207,11 +207,58 @@ def fctonx(n_rows, n_cols, kw):
207207
fctonx()
208208
self.assertEqual(len(args_onxp),len(args_np))
209209

210+
@ignore_warnings(UserWarning)
211+
deftest_square_shared_types(self):
212+
dtypes=self.onxps.scalar_dtypes()
213+
shared_dtypes=strategies.shared(dtypes,key="dtype")
214+
215+
defshapes(**kw):
216+
kw.setdefault("min_dims",0)
217+
kw.setdefault("min_side",0)
218+
returnself.onxps.array_shapes(**kw).filter(
219+
lambdashape:prod(iforiinshapeifi)<self.MAX_ARRAY_SIZE
220+
)
221+
222+
@strategies.composite
223+
defkwargs(draw,**kw):
224+
result= {}
225+
fork,stratinkw.items():
226+
ifdraw(strategies.booleans()):
227+
result[k]=draw(strat)
228+
returnresult
229+
230+
@strategies.composite
231+
deffull_like_fill_values(draw):
232+
kw=draw(
233+
strategies.shared(
234+
kwargs(dtype=strategies.none()|self.onxps.scalar_dtypes()),
235+
key="full_like_kw",
236+
)
237+
)
238+
dtype=kw.get("dtype",None)ordraw(shared_dtypes)
239+
returndraw(self.onxps.from_dtype(dtype))
240+
241+
args= []
242+
sh=shapes()
243+
xa=self.onxps.arrays(dtype=shared_dtypes,shape=sh)
244+
fu=full_like_fill_values()
245+
kws=strategies.shared(
246+
kwargs(dtype=strategies.none()|self.onxps.scalar_dtypes()),
247+
key="full_like_kw",
248+
)
249+
250+
@given(x=xa,fill_value=fu,kw=kws)
251+
deffctonp(x,fill_value,kw):
252+
args.append((x,fill_value,kw))
253+
254+
fctonp()
255+
self.assertEqual(len(args),100)
256+
210257

211258
if__name__=="__main__":
212-
#cl = TestHypothesisArraysApis()
213-
#cl.setUpClass()
214-
#cl.test_scalar_strategies()
259+
cl=TestHypothesisArraysApis()
260+
cl.setUpClass()
261+
cl.test_square_shared_types()
215262
# import logging
216263

217264
# logging.basicConfig(level=logging.DEBUG)

‎_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22
importunittest
33
importnumpyasnp
44
fromonnximportTensorProto
5-
fromonnx_array_api.ext_test_caseimportExtTestCase
5+
fromonnx_array_api.ext_test_caseimportExtTestCase,ignore_warnings
66
fromonnx_array_api.array_apiimportonnx_numpyasxp
77
fromonnx_array_api.npx.npx_typesimportDType
88
fromonnx_array_api.npx.npx_numpy_tensorsimportEagerNumpyTensorasEagerTensor
9+
fromonnx_array_api.npx.npx_functionsimportlinspaceaslinspace_inline
10+
fromonnx_array_api.npx.npx_typesimportFloat64,Int64
11+
fromonnx_array_api.npx.npx_varimportInput
12+
fromonnx_array_api.referenceimportExtendedReferenceEvaluator
913

1014

1115
classTestOnnxNumpy(ExtTestCase):
@@ -22,6 +26,7 @@ def test_zeros(self):
2226
a=xp.absolute(mat)
2327
self.assertEqualArray(np.absolute(mat.numpy()),a.numpy())
2428

29+
@ignore_warnings(DeprecationWarning)
2530
deftest_arange_default(self):
2631
a=EagerTensor(np.array([0],dtype=np.int64))
2732
b=EagerTensor(np.array([2],dtype=np.int64))
@@ -30,6 +35,7 @@ def test_arange_default(self):
3035
self.assertEqual(matnp.shape, (2,))
3136
self.assertEqualArray(matnp,np.arange(0,2).astype(np.int64))
3237

38+
@ignore_warnings(DeprecationWarning)
3339
deftest_arange_step(self):
3440
a=EagerTensor(np.array([4],dtype=np.int64))
3541
s=EagerTensor(np.array([2],dtype=np.int64))
@@ -78,6 +84,7 @@ def test_full_bool(self):
7884
self.assertNotEmpty(matnp[0,0])
7985
self.assertEqualArray(matnp,np.full((4,5),False))
8086

87+
@ignore_warnings(DeprecationWarning)
8188
deftest_arange_int00a(self):
8289
a=EagerTensor(np.array([0],dtype=np.int64))
8390
b=EagerTensor(np.array([0],dtype=np.int64))
@@ -89,6 +96,7 @@ def test_arange_int00a(self):
8996
expected=expected.astype(np.int64)
9097
self.assertEqualArray(matnp,expected)
9198

99+
@ignore_warnings(DeprecationWarning)
92100
deftest_arange_int00(self):
93101
mat=xp.arange(0,0)
94102
matnp=mat.numpy()
@@ -160,10 +168,94 @@ def test_eye_k(self):
160168
got=xp.eye(nr,k=1)
161169
self.assertEqualArray(expected,got.numpy())
162170

171+
deftest_linspace_int(self):
172+
a=EagerTensor(np.array([0],dtype=np.int64))
173+
b=EagerTensor(np.array([6],dtype=np.int64))
174+
c=EagerTensor(np.array(3,dtype=np.int64))
175+
mat=xp.linspace(a,b,c)
176+
matnp=mat.numpy()
177+
expected=np.linspace(a.numpy(),b.numpy(),c.numpy()).astype(np.int64)
178+
self.assertEqualArray(expected,matnp)
179+
180+
deftest_linspace_int5(self):
181+
a=EagerTensor(np.array([0],dtype=np.int64))
182+
b=EagerTensor(np.array([5],dtype=np.int64))
183+
c=EagerTensor(np.array(3,dtype=np.int64))
184+
mat=xp.linspace(a,b,c)
185+
matnp=mat.numpy()
186+
expected=np.linspace(a.numpy(),b.numpy(),c.numpy()).astype(np.int64)
187+
self.assertEqualArray(expected,matnp)
188+
189+
deftest_linspace_float(self):
190+
a=EagerTensor(np.array([0.5],dtype=np.float64))
191+
b=EagerTensor(np.array([5.5],dtype=np.float64))
192+
c=EagerTensor(np.array(2,dtype=np.int64))
193+
mat=xp.linspace(a,b,c)
194+
matnp=mat.numpy()
195+
expected=np.linspace(a.numpy(),b.numpy(),c.numpy())
196+
self.assertEqualArray(expected,matnp)
197+
198+
deftest_linspace_float_noendpoint(self):
199+
a=EagerTensor(np.array([0.5],dtype=np.float64))
200+
b=EagerTensor(np.array([5.5],dtype=np.float64))
201+
c=EagerTensor(np.array(2,dtype=np.int64))
202+
mat=xp.linspace(a,b,c,endpoint=0)
203+
matnp=mat.numpy()
204+
expected=np.linspace(a.numpy(),b.numpy(),c.numpy(),endpoint=0)
205+
self.assertEqualArray(expected,matnp)
206+
207+
@ignore_warnings((RuntimeWarning,DeprecationWarning))# division by zero
208+
deftest_linspace_zero(self):
209+
expected=np.linspace(0.0,0.0,0,endpoint=False)
210+
mat=xp.linspace(0.0,0.0,0,endpoint=False)
211+
matnp=mat.numpy()
212+
self.assertEqualArray(expected,matnp)
213+
214+
@ignore_warnings((RuntimeWarning,DeprecationWarning))# division by zero
215+
deftest_linspace_zero_one(self):
216+
expected=np.linspace(0.0,0.0,1,endpoint=True)
217+
218+
f=linspace_inline(Input("start"),Input("stop"),Input("num"))
219+
onx=f.to_onnx(
220+
constraints={
221+
"start":Float64[None],
222+
"stop":Float64[None],
223+
"num":Int64[None],
224+
(0,False):Float64[None],
225+
}
226+
)
227+
ref=ExtendedReferenceEvaluator(onx)
228+
got=ref.run(
229+
None,
230+
{
231+
"start":np.array(0,dtype=np.float64),
232+
"stop":np.array(0,dtype=np.float64),
233+
"num":np.array(1,dtype=np.int64),
234+
},
235+
)
236+
self.assertEqualArray(expected,got[0])
237+
238+
mat=xp.linspace(0.0,0.0,1,endpoint=True)
239+
matnp=mat.numpy()
240+
241+
self.assertEqualArray(expected,matnp)
242+
243+
deftest_slice_minus_one(self):
244+
g=EagerTensor(np.array([0.0]))
245+
expected=g.numpy()[:-1]
246+
got=g[:-1]
247+
self.assertEqualArray(expected,got.numpy())
248+
249+
deftest_linspace_bug1(self):
250+
expected=np.linspace(16777217.0,0.0,1)
251+
mat=xp.linspace(16777217.0,0.0,1)
252+
matnp=mat.numpy()
253+
self.assertEqualArray(expected,matnp)
254+
163255

164256
if__name__=="__main__":
165257
# import logging
166258

167259
# logging.basicConfig(level=logging.DEBUG)
168-
TestOnnxNumpy().test_eye()
260+
TestOnnxNumpy().test_linspace_float_noendpoint()
169261
unittest.main(verbosity=2)

‎_unittests/ut_npx/test_npx.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
fromonnx.referenceimportReferenceEvaluator
2121
fromonnx.shape_inferenceimportinfer_shapes
2222

23-
fromonnx_array_api.ext_test_caseimportExtTestCase
23+
fromonnx_array_api.ext_test_caseimportExtTestCase,ignore_warnings
24+
fromonnx_array_api.referenceimportExtendedReferenceEvaluator
2425
fromonnx_array_api.npximportElemType,eager_onnx,jit_onnx
2526
fromonnx_array_api.npx.npx_core_apiimport (
2627
cst,
@@ -60,6 +61,7 @@
6061
fromonnx_array_api.npx.npx_functionsimporthstackashstack_inline
6162
fromonnx_array_api.npx.npx_functionsimportidentityasidentity_inline
6263
fromonnx_array_api.npx.npx_functionsimportisnanasisnan_inline
64+
fromonnx_array_api.npx.npx_functionsimportlinspaceaslinspace_inline
6365
fromonnx_array_api.npx.npx_functionsimportlogaslog_inline
6466
fromonnx_array_api.npx.npx_functionsimportlog1paslog1p_inline
6567
fromonnx_array_api.npx.npx_functionsimportmatmulasmatmul_inline
@@ -1654,6 +1656,7 @@ def test_squeeze(self):
16541656
got=ref.run(None, {"A":x})
16551657
self.assertEqualArray(z,got[0])
16561658

1659+
@ignore_warnings(DeprecationWarning)
16571660
deftest_squeeze_noaxis(self):
16581661
f=squeeze_inline(copy_inline(Input("A")))
16591662
self.assertIsInstance(f,Var)
@@ -2574,6 +2577,51 @@ def test_get_item_i8(self):
25742577
i=a[0]
25752578
self.assertEqualArray(i.numpy(),a.numpy()[0])
25762579

2580+
@ignore_warnings(RuntimeWarning)
2581+
deftest_linspace_big_inline(self):
2582+
# linspace(5, 0, 1) --> [5] even with endpoint=True
2583+
f=linspace_inline(Input("A"),Input("B"),Input("C"))
2584+
self.assertIsInstance(f,Var)
2585+
onx=f.to_onnx(
2586+
constraints={
2587+
0:Int64[None],
2588+
1:Int64[None],
2589+
2:Int64[None],
2590+
(0,False):Int64[None],
2591+
}
2592+
)
2593+
2594+
start=np.array(16777217.0,dtype=np.float64)
2595+
stop=np.array(0.0,dtype=np.float64)
2596+
num=np.array(1,dtype=np.int64)
2597+
y=np.linspace(start,stop,num)
2598+
ref=ExtendedReferenceEvaluator(onx)
2599+
got=ref.run(None, {"A":start,"B":stop,"C":num})
2600+
self.assertEqualArray(y,got[0])
2601+
2602+
@ignore_warnings(RuntimeWarning)
2603+
deftest_linspace_inline(self):
2604+
# linspace(0, 5, 1)
2605+
f=linspace_inline(Input("A"),Input("B"),Input("C"))
2606+
self.assertIsInstance(f,Var)
2607+
onx=f.to_onnx(
2608+
constraints={
2609+
0:Int64[None],
2610+
1:Int64[None],
2611+
2:Int64[None],
2612+
(0,False):Int64[None],
2613+
}
2614+
)
2615+
2616+
start=np.array(0,dtype=np.float64)
2617+
stop=np.array(5,dtype=np.float64)
2618+
num=np.array(1,dtype=np.int64)
2619+
y=np.linspace(start,stop,num)
2620+
ref=ExtendedReferenceEvaluator(onx)
2621+
got=ref.run(None, {"A":start,"B":stop,"C":num})
2622+
self.assertEqualArray(y,got[0])
2623+
25772624

25782625
if__name__=="__main__":
2626+
TestNpx().test_linspace_inline()
25792627
unittest.main(verbosity=2)

‎onnx_array_api/array_api/__init__.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"isfinite",
2525
"isinf",
2626
"isnan",
27+
"linspace",
2728
"ones",
2829
"ones_like",
2930
"reshape",
@@ -40,11 +41,18 @@ def _finfo(dtype):
4041
"""
4142
dt=dtype.np_dtypeifisinstance(dtype,DType)elsedtype
4243
res=np.finfo(dt)
43-
d=res.__dict__.copy()
44+
d= {}
45+
fork,vinres.__dict__.items():
46+
ifk.startswith("__"):
47+
continue
48+
ifisinstance(v, (np.float32,np.float64,np.float16)):
49+
d[k]=float(v)
50+
else:
51+
d[k]=v
4452
d["dtype"]=DType(np_dtype_to_tensor_dtype(dt))
4553
nres=type("finfo", (res.__class__,),d)
46-
setattr(nres,"smallest_normal",res.smallest_normal)
47-
setattr(nres,"tiny",res.tiny)
54+
setattr(nres,"smallest_normal",float(res.smallest_normal))
55+
setattr(nres,"tiny",float(res.tiny))
4856
returnnres
4957

5058

@@ -54,11 +62,30 @@ def _iinfo(dtype):
5462
"""
5563
dt=dtype.np_dtypeifisinstance(dtype,DType)elsedtype
5664
res=np.iinfo(dt)
57-
d=res.__dict__.copy()
65+
d= {}
66+
fork,vinres.__dict__.items():
67+
ifk.startswith("__"):
68+
continue
69+
ifisinstance(
70+
v,
71+
(
72+
np.int16,
73+
np.int32,
74+
np.int64,
75+
np.uint16,
76+
np.uint32,
77+
np.uint64,
78+
np.int8,
79+
np.uint8,
80+
),
81+
):
82+
d[k]=int(v)
83+
else:
84+
d[k]=v
5885
d["dtype"]=DType(np_dtype_to_tensor_dtype(dt))
59-
nres=type("finfo", (res.__class__,),d)
60-
setattr(nres,"min",res.min)
61-
setattr(nres,"max",res.max)
86+
nres=type("iinfo", (res.__class__,),d)
87+
setattr(nres,"min",int(res.min))
88+
setattr(nres,"max",int(res.max))
6289
returnnres
6390

6491

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp