Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitc6a3718

Browse files
authored
Fixes asarray for the Array API (#25)
* Fixes asarray for the Array API* move
1 parent61eec9d commitc6a3718

File tree

7 files changed

+77
-17
lines changed

7 files changed

+77
-17
lines changed

‎_unittests/ut_array_api/test_hypothesis_array_api.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
importwarnings
33
fromosimportgetenv
44
fromfunctoolsimportreduce
5+
importnumpyasnp
56
fromoperatorimportmul
67
fromhypothesisimportgiven
78
fromonnx_array_api.ext_test_caseimportExtTestCase
@@ -89,24 +90,49 @@ def test_scalar_strategies(self):
8990

9091
args_np= []
9192

93+
xx=self.xps.arrays(dtype=dtypes["integer_dtypes"],shape=shapes(self.xps))
94+
kws=array_api_kwargs(dtype=strategies.none()|self.xps.scalar_dtypes())
95+
9296
@given(
93-
x=self.xps.arrays(dtype=dtypes["integer_dtypes"],shape=shapes(self.xps)),
94-
kw=array_api_kwargs(dtype=strategies.none()|self.xps.scalar_dtypes()),
97+
x=xx,
98+
kw=kws,
9599
)
96-
deffct(x,kw):
100+
deffctnp(x,kw):
101+
asa1=np.asarray(x)
102+
asa2=np.asarray(x,**kw)
103+
self.assertEqual(asa1.shape,asa2.shape)
97104
args_np.append((x,kw))
98105

99-
fct()
106+
fctnp()
100107
self.assertEqual(len(args_np),100)
101108

102109
args_onxp= []
103110

104111
xshape=shapes(self.onxps)
105112
xx=self.onxps.arrays(dtype=dtypes_onnx["integer_dtypes"],shape=xshape)
106-
kw=array_api_kwargs(dtype=strategies.none()|self.onxps.scalar_dtypes())
113+
kws=array_api_kwargs(dtype=strategies.none()|self.onxps.scalar_dtypes())
107114

108-
@given(x=xx,kw=kw)
115+
@given(x=xx,kw=kws)
109116
deffctonx(x,kw):
117+
asa=np.asarray(x.numpy())
118+
try:
119+
asp=onxp.asarray(x)
120+
exceptExceptionase:
121+
raiseAssertionError(f"asarray fails with x={x!r}, asp={asa!r}.")frome
122+
try:
123+
self.assertEqualArray(asa,asp.numpy())
124+
exceptAssertionErrorase:
125+
raiseAssertionError(
126+
f"x={x!r} kw={kw!r} asa={asa!r}, asp={asp!r}"
127+
)frome
128+
ifkw:
129+
try:
130+
asp2=onxp.asarray(x,**kw)
131+
exceptExceptionase:
132+
raiseAssertionError(
133+
f"asarray fails with x={x!r}, kw={kw!r}, asp={asa!r}."
134+
)frome
135+
self.assertEqual(asp.shape,asp2.shape)
110136
args_onxp.append((x,kw))
111137

112138
fctonx()

‎onnx_array_api/array_api/_onnx_common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
fromtypingimportAny,Optional
2+
importwarnings
23
importnumpyasnp
4+
5+
withwarnings.catch_warnings():
6+
warnings.simplefilter("ignore")
7+
fromnumpy.array_api._array_objectimportArray
38
from ..npx.npx_typesimport (
49
DType,
510
ElemType,
@@ -77,6 +82,10 @@ def asarray(
7782
v=TEagerTensor(np.array(a,dtype=np.str_))
7883
elifisinstance(a,list):
7984
v=TEagerTensor(np.array(a))
85+
elifisinstance(a,np.ndarray):
86+
v=TEagerTensor(a)
87+
elifisinstance(a,Array):
88+
v=TEagerTensor(np.asarray(a))
8089
else:
8190
raiseRuntimeError(f"Unexpected type{type(a)} for the first input.")
8291
ifdtypeisnotNone:

‎onnx_array_api/npx/npx_numpy_tensors.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
importwarnings
12
fromtypingimportAny,Callable,List,Optional,Tuple
23
importnumpyasnp
34
fromonnximportModelProto,TensorProto
@@ -221,13 +222,18 @@ def __bool__(self):
221222
ifself.shape== (0,):
222223
returnFalse
223224
iflen(self.shape)!=0:
224-
raiseValueError(
225-
f"Conversion to bool only works for scalar, not for{self!r}."
225+
warnings.warn(
226+
f"Conversion to bool only works for scalar, not for{self!r}, "
227+
f"bool(...)={bool(self._tensor)}."
226228
)
229+
try:
230+
returnbool(self._tensor)
231+
exceptValueErrorase:
232+
raiseValueError(f"Unable to convert{self} to bool.")frome
227233
returnbool(self._tensor)
228234

229235
def__int__(self):
230-
"Implicit conversion tobool."
236+
"Implicit conversion toint."
231237
iflen(self.shape)!=0:
232238
raiseValueError(
233239
f"Conversion to bool only works for scalar, not for{self!r}."
@@ -249,7 +255,7 @@ def __int__(self):
249255
returnint(self._tensor)
250256

251257
def__float__(self):
252-
"Implicit conversion tobool."
258+
"Implicit conversion tofloat."
253259
iflen(self.shape)!=0:
254260
raiseValueError(
255261
f"Conversion to bool only works for scalar, not for{self!r}."
@@ -261,11 +267,24 @@ def __float__(self):
261267
DType(TensorProto.BFLOAT16),
262268
}:
263269
raiseTypeError(
264-
f"Conversion toint only works for float scalar, "
270+
f"Conversion tofloat only works for float scalar, "
265271
f"not for dtype={self.dtype}."
266272
)
267273
returnfloat(self._tensor)
268274

275+
def__iter__(self):
276+
"""
277+
The :epkg:`Array API` does not define this function (2022/12).
278+
This method raises an exception with a better error message.
279+
"""
280+
warnings.warn(
281+
f"Iterators are not implemented in the generic case. "
282+
f"Every function using them cannot be converted into ONNX "
283+
f"(tensors -{type(self)})."
284+
)
285+
forrowinself._tensor:
286+
yieldself.__class__(row)
287+
269288

270289
classJitNumpyTensor(NumpyTensor,JitTensor):
271290
"""

‎onnx_array_api/npx/npx_tensors.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@ def __iter__(self):
3535
This method raises an exception with a better error message.
3636
"""
3737
raiseArrayApiError(
38-
"Iterators are not implemented in the generic case. "
39-
"Every function using them cannot be converted into ONNX."
38+
f"Iterators are not implemented in the generic case. "
39+
f"Every function using them cannot be converted into ONNX "
40+
f"(tensors -{type(self)})."
4041
)
4142

4243
@staticmethod

‎onnx_array_api/npx/npx_types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,16 @@ def __eq__(self, dt: "DType") -> bool:
5959
returnFalse
6060
ifdt.__class__isDType:
6161
returnself.code_==dt.code_
62-
ifisinstance(dt, (int,bool,str)):
62+
ifisinstance(dt, (int,bool,str,float)):
6363
returnFalse
64+
ifdtisint:
65+
returnself.code_==TensorProto.INT64
6466
ifdtisstr:
6567
returnself.code_==TensorProto.STRING
6668
ifdtisbool:
6769
returnself.code_==TensorProto.BOOL
70+
ifdtisfloat:
71+
returnself.code_==TensorProto.FLOAT64
6872
ifisinstance(dt,list):
6973
returnFalse
7074
ifdtinElemType.numpy_map:

‎onnx_array_api/npx/npx_var.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -607,8 +607,9 @@ def __iter__(self):
607607
This method raises an exception with a better error message.
608608
"""
609609
raiseArrayApiError(
610-
"Iterators are not implemented in the generic case. "
611-
"Every function using them cannot be converted into ONNX."
610+
f"Iterators are not implemented in the generic case. "
611+
f"Every function using them cannot be converted into ONNX "
612+
f"(Var -{type(self)})."
612613
)
613614

614615
def_binary_op(self,ov:"Var",op_name:str,**kwargs)->"Var":

‎requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ black
33
coverage
44
flake8
55
furo
6-
hypothesis<6.80.0
6+
hypothesis
77
isort
88
joblib
99
lightgbm

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp