Add full_like for the array API #26


Merged

sdpython merged 4 commits into main from fulll on Jul 2, 2023
1 change: 0 additions & 1 deletion in _unittests/onnx-numpy-skips.txt
@@ -6,7 +6,6 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays
array_api_tests/test_creation_functions.py::test_empty
array_api_tests/test_creation_functions.py::test_empty_like
array_api_tests/test_creation_functions.py::test_eye
array_api_tests/test_creation_functions.py::test_full_like
array_api_tests/test_creation_functions.py::test_linspace
array_api_tests/test_creation_functions.py::test_meshgrid
array_api_tests/test_creation_functions.py::test_zeros_like
6 changes: 3 additions & 3 deletions in _unittests/ut_array_api/test_hypothesis_array_api.py
@@ -140,7 +140,7 @@ def fctonx(x, kw):


if __name__ == "__main__":
cl = TestHypothesisArraysApis()
cl.setUpClass()
cl.test_scalar_strategies()
#cl = TestHypothesisArraysApis()
#cl.setUpClass()
#cl.test_scalar_strategies()
unittest.main(verbosity=2)
20 changes: 19 additions & 1 deletion in _unittests/ut_array_api/test_onnx_numpy.py
@@ -112,7 +112,25 @@ def test_ones_like_uint16(self):
expected = np.array(1, dtype=np.uint16)
self.assertEqualArray(expected, z.numpy())

def test_full_like(self):
c = EagerTensor(np.array(False))
expected = np.full_like(c.numpy(), fill_value=False)
mat = xp.full_like(c, fill_value=False)
matnp = mat.numpy()
self.assertEqual(matnp.shape, tuple())
self.assertEqualArray(expected, matnp)

def test_full_like_mx(self):
c = EagerTensor(np.array([], dtype=np.uint8))
expected = np.full_like(c.numpy(), fill_value=0)
mat = xp.full_like(c, fill_value=0)
matnp = mat.numpy()
self.assertEqualArray(expected, matnp)


if __name__ == "__main__":
# TestOnnxNumpy().test_ones_like()
# import logging

# logging.basicConfig(level=logging.DEBUG)
# TestOnnxNumpy().test_full_like_mx()
unittest.main(verbosity=2)
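
The two new tests above exercise full_like end to end against numpy.full_like. A minimal standalone sketch of the same check, assuming the test file's usual imports (xp bound to onnx_array_api.array_api.onnx_numpy and EagerTensor to the eager NumPy tensor class); those import paths are not visible in this hunk and are an assumption:

import numpy as np
# Assumed import paths (not shown in this hunk):
from onnx_array_api.array_api import onnx_numpy as xp
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor as EagerTensor

# Mirrors test_full_like above: a 0-d boolean input keeps its shape and value.
c = EagerTensor(np.array(False))
mat = xp.full_like(c, fill_value=False)
assert mat.numpy().shape == tuple()
assert np.array_equal(np.full_like(c.numpy(), fill_value=False), mat.numpy())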
7 changes: 4 additions & 3 deletions in azure-pipelines.yml
@@ -246,9 +246,10 @@ jobs:
architecture: 'x64'
- script: gcc --version
displayName: 'gcc version'
- script: |
brew update
displayName: 'brew update'
#- script: brew upgrade
# displayName: 'brew upgrade'
#- script: brew update
# displayName: 'brew update'
- script: export
displayName: 'export'
- script: gcc --version
1 change: 1 addition & 0 deletions in onnx_array_api/array_api/__init__.py
@@ -18,6 +18,7 @@
"empty",
"equal",
"full",
"full_like",
"isdtype",
"isfinite",
"isinf",
18 changes: 18 additions & 0 deletions in onnx_array_api/array_api/_onnx_common.py
@@ -20,6 +20,7 @@
abs as generic_abs,
arange as generic_arange,
full as generic_full,
full_like as generic_full_like,
ones as generic_ones,
zeros as generic_zeros,
)
@@ -177,6 +178,23 @@ def full(
return generic_full(shape, fill_value=value, dtype=dtype, order=order)


def full_like(
TEagerTensor: type,
x: TensorType[ElemType.allowed, "T"],
/,
fill_value: ParType[Scalar] = None,
*,
dtype: OptParType[DType] = None,
order: OptParType[str] = "C",
) -> EagerTensor[TensorType[ElemType.allowed, "TR"]]:
if dtype is None:
if isinstance(fill_value, TEagerTensor):
dtype = fill_value.dtype
elif isinstance(x, TEagerTensor):
dtype = x.dtype
return generic_full_like(x, fill_value=fill_value, dtype=dtype, order=order)


def ones(
TEagerTensor: type,
shape: EagerTensor[TensorType[ElemType.int64, "I", (None,)]],
47 changes: 45 additions & 2 deletions in onnx_array_api/npx/npx_functions.py
@@ -275,9 +275,9 @@ def astype(
if dtype is int:
to = DType(TensorProto.INT64)
elif dtype is float:
to = DType(TensorProto.FLOAT64)
to = DType(TensorProto.DOUBLE)
elif dtype is bool:
to = DType(TensorProto.FLOAT64)
to = DType(TensorProto.BOOL)
elif dtype is str:
to = DType(TensorProto.STRING)
else:
@@ -511,6 +511,49 @@ def full(
return var(shape, value=value, op="ConstantOfShape")


@npxapi_inline
def full_like(
x: TensorType[ElemType.allowed, "T"],
/,
*,
fill_value: ParType[Scalar] = None,
dtype: OptParType[DType] = None,
order: OptParType[str] = "C",
) -> TensorType[ElemType.numerics, "T"]:
"""
Implements :func:`numpy.full_like`.
"""
if order != "C":
raise RuntimeError(f"order={order!r} != 'C' not supported.")
if fill_value is None:
raise TypeError("fill_value cannot be None.")
if dtype is None:
if isinstance(fill_value, bool):
dtype = DType(TensorProto.BOOL)
elif isinstance(fill_value, int):
dtype = DType(TensorProto.INT64)
elif isinstance(fill_value, float):
dtype = DType(TensorProto.DOUBLE)
else:
raise TypeError(
f"Unexpected type {type(fill_value)} for fill_value={fill_value!r} "
f"and dtype={dtype!r}."
)
if isinstance(fill_value, (float, int, bool)):
value = make_tensor(
name="cst", data_type=dtype.code, dims=[1], vals=[fill_value]
)
else:
raise NotImplementedError(
f"Unexpected type ({type(fill_value)} for fill_value={fill_value!r}."
)

v = var(x.shape, value=value, op="ConstantOfShape")
if dtype is None:
return var(v, x, op="CastLike")
return v


@npxapi_inline
def floor(
x: TensorType[ElemType.numerics, "T"], /
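
When dtype is not given, the new full_like above infers the ONNX element type from the Python type of fill_value. A small sketch of that rule with a hypothetical helper (not part of the PR); the ordering matters because bool is a subclass of int in Python, so it must be tested first:

from onnx import TensorProto

def infer_dtype_from_fill_value(fill_value):
    # Same branching as full_like above, reduced to raw TensorProto codes.
    if isinstance(fill_value, bool):
        return TensorProto.BOOL
    if isinstance(fill_value, int):
        return TensorProto.INT64
    if isinstance(fill_value, float):
        return TensorProto.DOUBLE
    raise TypeError(f"Unexpected type {type(fill_value)} for fill_value={fill_value!r}.")

assert infer_dtype_from_fill_value(True) == TensorProto.BOOL  # not INT64
assert infer_dtype_from_fill_value(1) == TensorProto.INT64
assert infer_dtype_from_fill_value(1.0) == TensorProto.DOUBLE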
15 changes: 10 additions & 5 deletions in onnx_array_api/npx/npx_jit_eager.py
@@ -58,6 +58,7 @@ def info(
kwargs: Optional[Dict[str, Any]] = None,
key: Optional[Tuple[Any, ...]] = None,
onx: Optional[ModelProto] = None,
output: Optional[Any] = None,
):
"""
Logs a status.
@@ -93,6 +94,8 @@
"" if args is None else str(args),
"" if kwargs is None else str(kwargs),
)
if output is not None:
logger.debug("==== [%s]", output)

def status(self, me: str) -> str:
"""
@@ -517,7 +520,7 @@ def jit_call(self, *values, **kwargs):
f"f={self.f} from module {self.f.__module__!r} "
f"onnx=\n---\n{text}\n---\n{self.onxs[key]}"
) from e
self.info("-", "jit_call")
self.info("-", "jit_call", output=res)
return res


@@ -737,11 +740,13 @@ def __call__(self, *args, already_eager=False, **kwargs):
try:
res = self.f(*values, **kwargs)
except (AttributeError, TypeError) as e:
inp1 = ", ".join(map(str, map(type, args)))
inp2 = ", ".join(map(str, map(type, values)))
inp1 = ", ".join(map(str, map(lambda a:type(a).__name__, args)))
inp2 = ", ".join(map(str, map(lambda a:type(a).__name__, values)))
raise TypeError(
f"Unexpected types, input types are {inp1} "
f"and {inp2}, kwargs={kwargs}."
f"Unexpected types, input types are args=[{inp1}], "
f"values=[{inp2}], kwargs={kwargs}. "
f"(values = self._preprocess_constants(args)) "
f"args={args}, values={values}"
) from e

if isinstance(res, EagerTensor) or (
3 changes: 1 addition & 2 deletions in onnx_array_api/npx/npx_numpy_tensors.py
@@ -4,7 +4,6 @@
from onnx import ModelProto, TensorProto
from ..reference import ExtendedReferenceEvaluator
from .._helpers import np_dtype_to_tensor_dtype
from .npx_numpy_tensors_ops import ConstantOfShape
from .npx_tensors import EagerTensor, JitTensor
from .npx_types import DType, TensorType

@@ -36,7 +35,7 @@ def __init__(
onx: ModelProto,
f: Callable,
):
self.ref = ExtendedReferenceEvaluator(onx, new_ops=[ConstantOfShape])
self.ref = ExtendedReferenceEvaluator(onx)
self.input_names = input_names
self.tensor_class = tensor_class
self._f = f
2 changes: 1 addition & 1 deletion in onnx_array_api/npx/npx_types.py
@@ -68,7 +68,7 @@ def __eq__(self, dt: "DType") -> bool:
if dt is bool:
return self.code_ == TensorProto.BOOL
if dt is float:
return self.code_ == TensorProto.FLOAT64
return self.code_ == TensorProto.DOUBLE
if isinstance(dt, list):
return False
if dt in ElemType.numpy_map:
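
The same FLOAT64/DOUBLE correction appears in astype above: onnx.TensorProto has no FLOAT64 member, so the double-precision comparison must use TensorProto.DOUBLE. A quick check:

from onnx import TensorProto

# DOUBLE is the double-precision element type; FLOAT64 does not exist, so the
# previous comparisons would have raised AttributeError on that branch.
assert hasattr(TensorProto, "DOUBLE")
assert not hasattr(TensorProto, "FLOAT64")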
17 changes: 17 additions & 0 deletions in onnx_array_api/reference/evaluator.py
@@ -1,9 +1,18 @@
from logging import getLogger
from typing import Any, Dict, List, Optional, Union
from onnx import FunctionProto, ModelProto
from onnx.defs import get_schema
from onnx.reference import ReferenceEvaluator
from onnx.reference.op_run import OpRun
from .ops.op_cast_like import CastLike_15, CastLike_19
from .ops.op_constant_of_shape import ConstantOfShape

import onnx

print(onnx.__file__)


logger = getLogger("onnx-array-api-eval")


class ExtendedReferenceEvaluator(ReferenceEvaluator):
@@ -24,6 +33,7 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
default_ops = [
CastLike_15,
CastLike_19,
ConstantOfShape,
]

@staticmethod
@@ -88,3 +98,10 @@ def __init__(
new_ops=new_ops,
**kwargs,
)

def _log(self, level: int, pattern: str, *args: List[Any]) -> None:
if level < self.verbose:
new_args = [self._log_arg(a) for a in args]
print(pattern % tuple(new_args))
else:
logger.debug(pattern, *args)
onnx_array_api/reference/ops/op_constant_of_shape.py
@@ -1,12 +1,18 @@
import numpy as np

from onnx.reference.op_run import OpRun


class ConstantOfShape(OpRun):
@staticmethod
def _process(value):
cst = value[0] if isinstance(value, np.ndarray) else value
cst = value[0] if isinstance(value, np.ndarray) and value.size > 0 else value
if isinstance(value, np.ndarray):
if len(value.shape) == 0:
cst = value
elif value.size > 0:
cst = value.ravel()[0]
else:
raise ValueError(f"Unexpected fill_value={value!r}")
if isinstance(cst, bool):
cst = np.bool_(cst)
elif isinstance(cst, int):
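
The extra branches in _process above distinguish the three shapes the value attribute can arrive in; a short NumPy illustration of why a single value[0] is not enough:

import numpy as np

# 0-d array: integer indexing raises IndexError, so the value is used as-is.
v0 = np.array(False)
# non-empty 1-d array: the single element is taken with ravel()[0].
v1 = np.array([7], dtype=np.int64)
# empty array: there is nothing to take, so the new code raises ValueError.
ve = np.array([], dtype=np.uint8)

assert v0.shape == tuple() and v0.size == 1
assert v1.ravel()[0] == 7
assert ve.size == 0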
