Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit2fc79f6

Browse files
authored
Add full_like for the array API (#26)
* Add full_like for the array API* improvment* fix full_like
1 parentd248c16 commit2fc79f6

File tree

12 files changed

+127
-20
lines changed

12 files changed

+127
-20
lines changed

‎_unittests/onnx-numpy-skips.txt‎

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays
66
array_api_tests/test_creation_functions.py::test_empty
77
array_api_tests/test_creation_functions.py::test_empty_like
88
array_api_tests/test_creation_functions.py::test_eye
9-
array_api_tests/test_creation_functions.py::test_full_like
109
array_api_tests/test_creation_functions.py::test_linspace
1110
array_api_tests/test_creation_functions.py::test_meshgrid
1211
array_api_tests/test_creation_functions.py::test_zeros_like

‎_unittests/ut_array_api/test_hypothesis_array_api.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def fctonx(x, kw):
140140

141141

142142
if__name__=="__main__":
143-
cl=TestHypothesisArraysApis()
144-
cl.setUpClass()
145-
cl.test_scalar_strategies()
143+
#cl = TestHypothesisArraysApis()
144+
#cl.setUpClass()
145+
#cl.test_scalar_strategies()
146146
unittest.main(verbosity=2)

‎_unittests/ut_array_api/test_onnx_numpy.py‎

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,25 @@ def test_ones_like_uint16(self):
112112
expected=np.array(1,dtype=np.uint16)
113113
self.assertEqualArray(expected,z.numpy())
114114

115+
deftest_full_like(self):
116+
c=EagerTensor(np.array(False))
117+
expected=np.full_like(c.numpy(),fill_value=False)
118+
mat=xp.full_like(c,fill_value=False)
119+
matnp=mat.numpy()
120+
self.assertEqual(matnp.shape,tuple())
121+
self.assertEqualArray(expected,matnp)
122+
123+
deftest_full_like_mx(self):
124+
c=EagerTensor(np.array([],dtype=np.uint8))
125+
expected=np.full_like(c.numpy(),fill_value=0)
126+
mat=xp.full_like(c,fill_value=0)
127+
matnp=mat.numpy()
128+
self.assertEqualArray(expected,matnp)
129+
115130

116131
if__name__=="__main__":
117-
# TestOnnxNumpy().test_ones_like()
132+
# import logging
133+
134+
# logging.basicConfig(level=logging.DEBUG)
135+
# TestOnnxNumpy().test_full_like_mx()
118136
unittest.main(verbosity=2)

‎azure-pipelines.yml‎

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,10 @@ jobs:
246246
architecture:'x64'
247247
-script:gcc --version
248248
displayName:'gcc version'
249-
-script:|
250-
brew update
251-
displayName: 'brew update'
249+
#- script: brew upgrade
250+
# displayName: 'brew upgrade'
251+
#- script: brew update
252+
# displayName: 'brew update'
252253
-script:export
253254
displayName:'export'
254255
-script:gcc --version

‎onnx_array_api/array_api/__init__.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"empty",
1919
"equal",
2020
"full",
21+
"full_like",
2122
"isdtype",
2223
"isfinite",
2324
"isinf",

‎onnx_array_api/array_api/_onnx_common.py‎

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
absasgeneric_abs,
2121
arangeasgeneric_arange,
2222
fullasgeneric_full,
23+
full_likeasgeneric_full_like,
2324
onesasgeneric_ones,
2425
zerosasgeneric_zeros,
2526
)
@@ -177,6 +178,23 @@ def full(
177178
returngeneric_full(shape,fill_value=value,dtype=dtype,order=order)
178179

179180

181+
deffull_like(
182+
TEagerTensor:type,
183+
x:TensorType[ElemType.allowed,"T"],
184+
/,
185+
fill_value:ParType[Scalar]=None,
186+
*,
187+
dtype:OptParType[DType]=None,
188+
order:OptParType[str]="C",
189+
)->EagerTensor[TensorType[ElemType.allowed,"TR"]]:
190+
ifdtypeisNone:
191+
ifisinstance(fill_value,TEagerTensor):
192+
dtype=fill_value.dtype
193+
elifisinstance(x,TEagerTensor):
194+
dtype=x.dtype
195+
returngeneric_full_like(x,fill_value=fill_value,dtype=dtype,order=order)
196+
197+
180198
defones(
181199
TEagerTensor:type,
182200
shape:EagerTensor[TensorType[ElemType.int64,"I", (None,)]],

‎onnx_array_api/npx/npx_functions.py‎

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,9 @@ def astype(
275275
ifdtypeisint:
276276
to=DType(TensorProto.INT64)
277277
elifdtypeisfloat:
278-
to=DType(TensorProto.FLOAT64)
278+
to=DType(TensorProto.DOUBLE)
279279
elifdtypeisbool:
280-
to=DType(TensorProto.FLOAT64)
280+
to=DType(TensorProto.BOOL)
281281
elifdtypeisstr:
282282
to=DType(TensorProto.STRING)
283283
else:
@@ -511,6 +511,49 @@ def full(
511511
returnvar(shape,value=value,op="ConstantOfShape")
512512

513513

514+
@npxapi_inline
515+
deffull_like(
516+
x:TensorType[ElemType.allowed,"T"],
517+
/,
518+
*,
519+
fill_value:ParType[Scalar]=None,
520+
dtype:OptParType[DType]=None,
521+
order:OptParType[str]="C",
522+
)->TensorType[ElemType.numerics,"T"]:
523+
"""
524+
Implements :func:`numpy.zeros`.
525+
"""
526+
iforder!="C":
527+
raiseRuntimeError(f"order={order!r} != 'C' not supported.")
528+
iffill_valueisNone:
529+
raiseTypeError("fill_value cannot be None.")
530+
ifdtypeisNone:
531+
ifisinstance(fill_value,bool):
532+
dtype=DType(TensorProto.BOOL)
533+
elifisinstance(fill_value,int):
534+
dtype=DType(TensorProto.INT64)
535+
elifisinstance(fill_value,float):
536+
dtype=DType(TensorProto.DOUBLE)
537+
else:
538+
raiseTypeError(
539+
f"Unexpected type{type(fill_value)} for fill_value={fill_value!r} "
540+
f"and dtype={dtype!r}."
541+
)
542+
ifisinstance(fill_value, (float,int,bool)):
543+
value=make_tensor(
544+
name="cst",data_type=dtype.code,dims=[1],vals=[fill_value]
545+
)
546+
else:
547+
raiseNotImplementedError(
548+
f"Unexpected type ({type(fill_value)} for fill_value={fill_value!r}."
549+
)
550+
551+
v=var(x.shape,value=value,op="ConstantOfShape")
552+
ifdtypeisNone:
553+
returnvar(v,x,op="CastLike")
554+
returnv
555+
556+
514557
@npxapi_inline
515558
deffloor(
516559
x:TensorType[ElemType.numerics,"T"],/

‎onnx_array_api/npx/npx_jit_eager.py‎

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def info(
5858
kwargs:Optional[Dict[str,Any]]=None,
5959
key:Optional[Tuple[Any, ...]]=None,
6060
onx:Optional[ModelProto]=None,
61+
output:Optional[Any]=None,
6162
):
6263
"""
6364
Logs a status.
@@ -93,6 +94,8 @@ def info(
9394
""ifargsisNoneelsestr(args),
9495
""ifkwargsisNoneelsestr(kwargs),
9596
)
97+
ifoutputisnotNone:
98+
logger.debug("==== [%s]",output)
9699

97100
defstatus(self,me:str)->str:
98101
"""
@@ -517,7 +520,7 @@ def jit_call(self, *values, **kwargs):
517520
f"f={self.f} from module{self.f.__module__!r} "
518521
f"onnx=\n---\n{text}\n---\n{self.onxs[key]}"
519522
)frome
520-
self.info("-","jit_call")
523+
self.info("-","jit_call",output=res)
521524
returnres
522525

523526

@@ -737,11 +740,13 @@ def __call__(self, *args, already_eager=False, **kwargs):
737740
try:
738741
res=self.f(*values,**kwargs)
739742
except (AttributeError,TypeError)ase:
740-
inp1=", ".join(map(str,map(type,args)))
741-
inp2=", ".join(map(str,map(type,values)))
743+
inp1=", ".join(map(str,map(lambdaa:type(a).__name__,args)))
744+
inp2=", ".join(map(str,map(lambdaa:type(a).__name__,values)))
742745
raiseTypeError(
743-
f"Unexpected types, input types are{inp1} "
744-
f"and{inp2}, kwargs={kwargs}."
746+
f"Unexpected types, input types are args=[{inp1}], "
747+
f"values=[{inp2}], kwargs={kwargs}. "
748+
f"(values = self._preprocess_constants(args)) "
749+
f"args={args}, values={values}"
745750
)frome
746751

747752
ifisinstance(res,EagerTensor)or (

‎onnx_array_api/npx/npx_numpy_tensors.py‎

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
fromonnximportModelProto,TensorProto
55
from ..referenceimportExtendedReferenceEvaluator
66
from .._helpersimportnp_dtype_to_tensor_dtype
7-
from .npx_numpy_tensors_opsimportConstantOfShape
87
from .npx_tensorsimportEagerTensor,JitTensor
98
from .npx_typesimportDType,TensorType
109

@@ -36,7 +35,7 @@ def __init__(
3635
onx:ModelProto,
3736
f:Callable,
3837
):
39-
self.ref=ExtendedReferenceEvaluator(onx,new_ops=[ConstantOfShape])
38+
self.ref=ExtendedReferenceEvaluator(onx)
4039
self.input_names=input_names
4140
self.tensor_class=tensor_class
4241
self._f=f

‎onnx_array_api/npx/npx_types.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __eq__(self, dt: "DType") -> bool:
6868
ifdtisbool:
6969
returnself.code_==TensorProto.BOOL
7070
ifdtisfloat:
71-
returnself.code_==TensorProto.FLOAT64
71+
returnself.code_==TensorProto.DOUBLE
7272
ifisinstance(dt,list):
7373
returnFalse
7474
ifdtinElemType.numpy_map:

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp