Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit934a139

Browse files
xadupresdpython
andauthored
Support function arange in Array API (#19)
* add arange* introduce OptTensorType* add more tests* better error handling* add kwargs_to_input* fix inconcistencies* improvments* fix one type issue* issue with windows* set* remove unnecessary code* improvments* fix names* fix missing name* fix arange* fix arange* fix unit test for windows---------Co-authored-by: xavier dupré <xavier.dupre@gmail.com>
1 parentc1f0a77 commit934a139

19 files changed

+518
-140
lines changed

‎.gitignore‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ _doc/_static/viz.js
2424
_unittests/ut__main/*.png
2525
_unittests/ut__main/_cache/*
2626
_unittests/ut__main/*.html
27+
_unittests/.hypothesis/*

‎_doc/api/index.rst‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ API
88

99
array_api
1010
npx_functions
11-
npx_var
1211
npx_jit
13-
npx_annot
1412
npx_numpy
13+
npx_types
14+
npx_var
1515
onnx_tools
1616
ort
1717
plotting

‎_doc/api/npx_annot.rst‎renamed to ‎_doc/api/npx_types.rst‎

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,40 @@
1-
=============
21
npx.npx_types
32
=============
43

54
DType
6-
=====
5+
+++++
76

87
..autoclass::onnx_array_api.npx.npx_types.DType
98
:members:
109

11-
Annotations
12-
===========
13-
1410
ElemType
1511
++++++++
1612

1713
..autoclass::onnx_array_api.npx.npx_types.ElemType
1814
:members:
1915

20-
ParType
21-
+++++++
22-
23-
..autoclass::onnx_array_api.npx.npx_types.ParType
24-
:members:
25-
2616
OptParType
2717
++++++++++
2818

2919
..autoclass::onnx_array_api.npx.npx_types.OptParType
3020
:members:
3121

32-
TensorType
33-
++++++++++
22+
OptTensorType
23+
+++++++++++++
3424

35-
..autoclass::onnx_array_api.npx.npx_types.TensorType
25+
..autoclass::onnx_array_api.npx.npx_types.OptTensorType
26+
:members:
27+
28+
ParType
29+
+++++++
30+
31+
..autoclass::onnx_array_api.npx.npx_types.ParType
32+
:members:
33+
34+
Scalar
35+
++++++
36+
37+
..autoclass::onnx_array_api.npx.npx_types.Scalar
3638
:members:
3739

3840
SequenceType
@@ -41,6 +43,18 @@ SequenceType
4143
..autoclass::onnx_array_api.npx.npx_types.SequenceType
4244
:members:
4345

46+
ShapeType
47+
+++++++++
48+
49+
..autoclass::onnx_array_api.npx.npx_types.ShapeType
50+
:members:
51+
52+
TensorType
53+
++++++++++
54+
55+
..autoclass::onnx_array_api.npx.npx_types.TensorType
56+
:members:
57+
4458
TupleType
4559
+++++++++
4660

‎_doc/api/npx_var.rst‎

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,16 @@ Cst, Input
1515

1616
..autoclass::onnx_array_api.npx.npx_var.Input
1717
:members:
18+
19+
ManyIdentity
20+
++++++++++++
21+
22+
..autoclass::onnx_array_api.npx.npx_var.ManyIdentity
23+
:members:
24+
25+
Par
26+
+++
27+
28+
..autoclass::onnx_array_api.npx.npx_var.Par
29+
:members:
30+

‎_unittests/onnx-numpy-skips.txt‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# API failures
22
# see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt
3-
array_api_tests/test_creation_functions.py::test_arange
3+
array_api_tests/test_creation_functions.py::test_asarray_scalars
4+
# array_api_tests/test_creation_functions.py::test_arange
45
array_api_tests/test_creation_functions.py::test_asarray_arrays
56
array_api_tests/test_creation_functions.py::test_empty
67
array_api_tests/test_creation_functions.py::test_empty_like

‎_unittests/test_array_api.sh‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2-
pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_scalars||exit 1
2+
pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_arange||exit 1
33
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
4-
pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt||exit 1
4+
pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt||exit 1

‎_unittests/ut_array_api/test_onnx_numpy.py‎

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
importsys
12
importunittest
23
importnumpyasnp
34
fromonnx_array_api.ext_test_caseimportExtTestCase
@@ -19,6 +20,22 @@ def test_zeros(self):
1920
a=xp.absolute(mat)
2021
self.assertEqualArray(np.absolute(mat.numpy()),a.numpy())
2122

23+
deftest_arange_default(self):
24+
a=EagerTensor(np.array([0],dtype=np.int64))
25+
b=EagerTensor(np.array([2],dtype=np.int64))
26+
mat=xp.arange(a,b)
27+
matnp=mat.numpy()
28+
self.assertEqual(matnp.shape, (2,))
29+
self.assertEqualArray(matnp,np.arange(0,2).astype(np.int64))
30+
31+
deftest_arange_step(self):
32+
a=EagerTensor(np.array([4],dtype=np.int64))
33+
s=EagerTensor(np.array([2],dtype=np.int64))
34+
mat=xp.arange(a,step=s)
35+
matnp=mat.numpy()
36+
self.assertEqual(matnp.shape, (2,))
37+
self.assertEqualArray(matnp,np.arange(4,step=2).astype(np.int64))
38+
2239
deftest_zeros_none(self):
2340
c=EagerTensor(np.array([4,5],dtype=np.int64))
2441
mat=xp.zeros(c)
@@ -52,7 +69,27 @@ def test_full_bool(self):
5269
self.assertNotEmpty(matnp[0,0])
5370
self.assertEqualArray(matnp,np.full((4,5),False))
5471

72+
deftest_arange_int00a(self):
73+
a=EagerTensor(np.array([0],dtype=np.int64))
74+
b=EagerTensor(np.array([0],dtype=np.int64))
75+
mat=xp.arange(a,b)
76+
matnp=mat.numpy()
77+
self.assertEqual(matnp.shape, (0,))
78+
expected=np.arange(0,0)
79+
ifsys.platform=="win32":
80+
expected=expected.astype(np.int64)
81+
self.assertEqualArray(matnp,expected)
82+
83+
deftest_arange_int00(self):
84+
mat=xp.arange(0,0)
85+
matnp=mat.numpy()
86+
self.assertEqual(matnp.shape, (0,))
87+
expected=np.arange(0,0)
88+
ifsys.platform=="win32":
89+
expected=expected.astype(np.int64)
90+
self.assertEqualArray(matnp,expected)
91+
5592

5693
if__name__=="__main__":
57-
TestOnnxNumpy().test_zeros_none()
94+
TestOnnxNumpy().test_arange_int00()
5895
unittest.main(verbosity=2)

‎_unittests/ut_npx/test_npx.py‎

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
Int64,
104104
OptParType,
105105
TensorType,
106+
OptTensorType,
106107
)
107108
fromonnx_array_api.npx.npx_varimportInput,Var
108109

@@ -125,35 +126,62 @@ def test_shape_inference(self):
125126
self.assertEqual(output.type.tensor_type.elem_type,TensorProto.FLOAT)
126127

127128
deftest_tensor(self):
128-
dt=TensorType["float32"]
129+
dt=TensorType["float32","F32"]
129130
self.assertEqual(len(dt.dtypes),1)
130131
self.assertEqual(dt.dtypes[0].dtype,ElemType.float32)
131132
self.assertEmpty(dt.shape)
132-
self.assertEqual(dt.type_name(),"TensorType['float32']")
133+
self.assertEqual(dt.type_name(),"TensorType['float32', 'F32']")
133134

134-
dt=TensorType["float32"]
135+
dt=TensorType["float32","F32"]
135136
self.assertEqual(len(dt.dtypes),1)
136137
self.assertEqual(dt.dtypes[0].dtype,ElemType.float32)
137-
self.assertEqual(dt.type_name(),"TensorType['float32']")
138+
self.assertEqual(dt.type_name(),"TensorType['float32', 'F32']")
138139

139-
dt=TensorType[np.float32]
140+
dt=TensorType[np.float32,"F32"]
140141
self.assertEqual(len(dt.dtypes),1)
141142
self.assertEqual(dt.dtypes[0].dtype,ElemType.float32)
142-
self.assertEqual(dt.type_name(),"TensorType['float32']")
143+
self.assertEqual(dt.type_name(),"TensorType['float32', 'F32']")
143144
self.assertEmpty(dt.shape)
144145

145-
dt=TensorType[np.str_]
146+
dt=TensorType[np.str_,"TEXT"]
146147
self.assertEqual(len(dt.dtypes),1)
147148
self.assertEqual(dt.dtypes[0].dtype,ElemType.str_)
148-
self.assertEqual(dt.type_name(),"TensorType[strings]")
149+
self.assertEqual(dt.type_name(),"TensorType[strings, 'TEXT']")
150+
self.assertEmpty(dt.shape)
151+
152+
self.assertRaise(lambda:TensorType[None],TypeError)
153+
self.assertRaise(lambda:TensorType[{np.float32,np.str_}],TypeError)
154+
155+
deftest_opt_tensor(self):
156+
dt=OptTensorType["float32","F32"]
157+
self.assertEqual(len(dt.dtypes),1)
158+
self.assertEqual(dt.dtypes[0].dtype,ElemType.float32)
159+
self.assertEmpty(dt.shape)
160+
self.assertEqual(dt.type_name(),"OptTensorType['float32', 'F32']")
161+
162+
dt=OptTensorType["float32","F32"]
163+
self.assertEqual(len(dt.dtypes),1)
164+
self.assertEqual(dt.dtypes[0].dtype,ElemType.float32)
165+
self.assertEqual(dt.type_name(),"OptTensorType['float32', 'F32']")
166+
167+
dt=OptTensorType[np.float32,"F32"]
168+
self.assertEqual(len(dt.dtypes),1)
169+
self.assertEqual(dt.dtypes[0].dtype,ElemType.float32)
170+
self.assertEqual(dt.type_name(),"OptTensorType['float32', 'F32']")
171+
self.assertEmpty(dt.shape)
172+
173+
dt=OptTensorType[np.str_,"TEXT"]
174+
self.assertEqual(len(dt.dtypes),1)
175+
self.assertEqual(dt.dtypes[0].dtype,ElemType.str_)
176+
self.assertEqual(dt.type_name(),"OptTensorType[strings, 'TEXT']")
149177
self.assertEmpty(dt.shape)
150178

151179
self.assertRaise(lambda:TensorType[None],TypeError)
152180
self.assertRaise(lambda:TensorType[{np.float32,np.str_}],TypeError)
153181

154182
deftest_superset(self):
155-
t1=TensorType[ElemType.numerics]
156-
t2=TensorType[ElemType.float64]
183+
t1=TensorType[ElemType.numerics,"T"]
184+
t2=TensorType[ElemType.float64,"F64"]
157185
self.assertTrue(t1.issuperset(t2))
158186
t1=Float32[None]
159187
t2=Float32[None]
@@ -167,14 +195,14 @@ def test_superset(self):
167195
t1=Float32["N"]
168196
t2=Float32[5]
169197
self.assertTrue(t1.issuperset(t2))
170-
t1=TensorType[ElemType.int64]
198+
t1=TensorType[ElemType.int64,"I"]
171199
t2=Int64[1]
172200
self.assertTrue(t1.issuperset(t2))
173201

174202
deftest_sig(self):
175203
deflocal1(
176-
x:TensorType[ElemType.floats],
177-
)->TensorType[ElemType.floats]:
204+
x:TensorType[ElemType.floats,"T"],
205+
)->TensorType[ElemType.floats,"T"]:
178206
returnx
179207

180208
deflocal2(
@@ -2536,13 +2564,17 @@ def test_numpy_all_empty_axis_1(self):
25362564
got=ref.run(None, {"A":data})
25372565
self.assertEqualArray(y,got[0])
25382566

2539-
@unittest.skipIf(True,reason="Fails to follow Array API")
2540-
deftest_get_item(self):
2567+
deftest_get_item_b(self):
25412568
a=EagerNumpyTensor(np.array([True],dtype=np.bool_))
25422569
i=a[0]
25432570
self.assertEqualArray(i.numpy(),a.numpy()[0])
25442571

2572+
deftest_get_item_i8(self):
2573+
a=EagerNumpyTensor(np.array([5,6],dtype=np.int8))
2574+
i=a[0]
2575+
self.assertEqualArray(i.numpy(),a.numpy()[0])
2576+
25452577

25462578
if__name__=="__main__":
2547-
#TestNpx().test_get_item()
2579+
TestNpx().test_filter()
25482580
unittest.main(verbosity=2)
File renamed without changes.

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp