Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Add function Eye to the Array API#29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
sdpython merged 4 commits intomainfromasa
Jul 2, 2023
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions_unittests/onnx-numpy-skips.txt
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
# API failures
# see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt
array_api_tests/test_creation_functions.py::test_asarray_scalars
array_api_tests/test_creation_functions.py::test_arange
# uses __setitem__
array_api_tests/test_creation_functions.py::test_asarray_arrays
array_api_tests/test_creation_functions.py::test_empty
array_api_tests/test_creation_functions.py::test_empty_like
array_api_tests/test_creation_functions.py::test_eye
array_api_tests/test_creation_functions.py::test_linspace
array_api_tests/test_creation_functions.py::test_meshgrid
2 changes: 1 addition & 1 deletion_unittests/test_array_api.sh
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros_like || exit 1
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_eye || exit 1
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1
72 changes: 72 additions & 0 deletions_unittests/ut_array_api/test_hypothesis_array_api.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -39,6 +39,7 @@ def sh(x):

class TestHypothesisArraysApis(ExtTestCase):
MAX_ARRAY_SIZE = 10000
SQRT_MAX_ARRAY_SIZE = int(10000**0.5)
VERSION = "2021.12"

@classmethod
Expand DownExpand Up@@ -138,9 +139,80 @@ def fctonx(x, kw):
fctonx()
self.assertEqual(len(args_onxp), len(args_np))

def test_square_sizes_strategies(self):
dtypes = dict(
integer_dtypes=self.xps.integer_dtypes(),
uinteger_dtypes=self.xps.unsigned_integer_dtypes(),
floating_dtypes=self.xps.floating_dtypes(),
numeric_dtypes=self.xps.numeric_dtypes(),
boolean_dtypes=self.xps.boolean_dtypes(),
scalar_dtypes=self.xps.scalar_dtypes(),
)

dtypes_onnx = dict(
integer_dtypes=self.onxps.integer_dtypes(),
uinteger_dtypes=self.onxps.unsigned_integer_dtypes(),
floating_dtypes=self.onxps.floating_dtypes(),
numeric_dtypes=self.onxps.numeric_dtypes(),
boolean_dtypes=self.onxps.boolean_dtypes(),
scalar_dtypes=self.onxps.scalar_dtypes(),
)

for k, vnp in dtypes.items():
vonxp = dtypes_onnx[k]
anp = self.xps.arrays(dtype=vnp, shape=shapes(self.xps))
aonxp = self.onxps.arrays(dtype=vonxp, shape=shapes(self.onxps))
self.assertNotEmpty(anp)
self.assertNotEmpty(aonxp)

args_np = []

kws = array_api_kwargs(k=strategies.integers(), dtype=self.xps.numeric_dtypes())
sqrt_sizes = strategies.integers(0, self.SQRT_MAX_ARRAY_SIZE)
ncs = strategies.none() | sqrt_sizes

@given(n_rows=sqrt_sizes, n_cols=ncs, kw=kws)
def fctnp(n_rows, n_cols, kw):
base = np.asarray(0)
e = np.eye(n_rows, n_cols)
self.assertNotEmpty(e.dtype)
self.assertIsInstance(e, base.__class__)
e = np.eye(n_rows, n_cols, **kw)
self.assertNotEmpty(e.dtype)
self.assertIsInstance(e, base.__class__)
args_np.append((n_rows, n_cols, kw))

fctnp()
self.assertEqual(len(args_np), 100)

args_onxp = []

kws = array_api_kwargs(
k=strategies.integers(), dtype=self.onxps.numeric_dtypes()
)
sqrt_sizes = strategies.integers(0, self.SQRT_MAX_ARRAY_SIZE)
ncs = strategies.none() | sqrt_sizes

@given(n_rows=sqrt_sizes, n_cols=ncs, kw=kws)
def fctonx(n_rows, n_cols, kw):
base = onxp.asarray(0)
e = onxp.eye(n_rows, n_cols)
self.assertIsInstance(e, base.__class__)
self.assertNotEmpty(e.dtype)
e = onxp.eye(n_rows, n_cols, **kw)
self.assertNotEmpty(e.dtype)
self.assertIsInstance(e, base.__class__)
args_onxp.append((n_rows, n_cols, kw))

fctonx()
self.assertEqual(len(args_onxp), len(args_np))


if __name__ == "__main__":
# cl = TestHypothesisArraysApis()
# cl.setUpClass()
# cl.test_scalar_strategies()
# import logging

# logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
20 changes: 19 additions & 1 deletion_unittests/ut_array_api/test_onnx_numpy.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -142,10 +142,28 @@ def test_as_array(self):
self.assertEqual(r.dtype, DType(TensorProto.UINT64))
self.assertEqual(r.numpy(), 9223372036854775809)

def test_eye(self):
nr, nc = xp.asarray(4), xp.asarray(4)
expected = np.eye(nr.numpy(), nc.numpy())
got = xp.eye(nr, nc)
self.assertEqualArray(expected, got.numpy())

def test_eye_nosquare(self):
nr, nc = xp.asarray(4), xp.asarray(5)
expected = np.eye(nr.numpy(), nc.numpy())
got = xp.eye(nr, nc)
self.assertEqualArray(expected, got.numpy())

def test_eye_k(self):
nr = xp.asarray(4)
expected = np.eye(nr.numpy(), k=1)
got = xp.eye(nr, k=1)
self.assertEqualArray(expected, got.numpy())


if __name__ == "__main__":
# import logging

# logging.basicConfig(level=logging.DEBUG)
#TestOnnxNumpy().test_as_array()
TestOnnxNumpy().test_eye()
unittest.main(verbosity=2)
1 change: 1 addition & 0 deletionsonnx_array_api/array_api/__init__.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -17,6 +17,7 @@
"astype",
"empty",
"equal",
"eye",
"full",
"full_like",
"isdtype",
Expand Down
21 changes: 21 additions & 0 deletionsonnx_array_api/array_api/_onnx_common.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
from typing import Any, Optional
import warnings
import numpy as np
from onnx import TensorProto

with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand All@@ -19,6 +20,8 @@
from ..npx.npx_functions import (
abs as generic_abs,
arange as generic_arange,
copy as copy_inline,
eye as generic_eye,
full as generic_full,
full_like as generic_full_like,
ones as generic_ones,
Expand DownExpand Up@@ -185,6 +188,24 @@ def full(
return generic_full(shape, fill_value=value, dtype=dtype, order=order)


def eye(
TEagerTensor: type,
n_rows: TensorType[ElemType.int64, "I"],
n_cols: OptTensorType[ElemType.int64, "I"] = None,
/,
*,
k: ParType[int] = 0,
dtype: ParType[DType] = DType(TensorProto.DOUBLE),
):
if isinstance(n_rows, int):
n_rows = TEagerTensor(np.array(n_rows, dtype=np.int64))
if n_cols is None:
n_cols = n_rows
elif isinstance(n_cols, int):
n_cols = TEagerTensor(np.array(n_cols, dtype=np.int64))
return generic_eye(n_rows, n_cols, k=k, dtype=dtype)


def full_like(
TEagerTensor: type,
x: TensorType[ElemType.allowed, "T"],
Expand Down
24 changes: 24 additions & 0 deletionsonnx_array_api/npx/npx_functions.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -473,6 +473,30 @@ def expit(
return var(x, op="Sigmoid")


@npxapi_inline
def eye(
n_rows: TensorType[ElemType.int64, "I"],
n_cols: TensorType[ElemType.int64, "I"],
/,
*,
k: ParType[int] = 0,
dtype: ParType[DType] = DType(TensorProto.DOUBLE),
):
"See :func:`numpy.eye`."
shape = cst(np.array([-1], dtype=np.int64))
shape = var(
var(n_rows, shape, op="Reshape"),
var(n_cols, shape, op="Reshape"),
axis=0,
op="Concat",
)
zero = zeros(shape, dtype=dtype)
res = var(zero, k=k, op="EyeLike")
if dtype is not None:
return var(res, to=dtype.code, op="Cast")
return res


@npxapi_inline
def full(
shape: TensorType[ElemType.int64, "I", (None,)],
Expand Down
10 changes: 10 additions & 0 deletionsonnx_array_api/npx/npx_graph_builder.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -230,6 +230,11 @@ def make_node(
new_kwargs[k] = v.value
elif isinstance(v, DType):
new_kwargs[k] = v.code
elif isinstance(v, int):
try:
new_kwargs[k] = int(np.array(v, dtype=np.int64))
except OverflowError:
new_kwargs[k] = int(np.iinfo(np.int64).max)
else:
new_kwargs[k] = v

Expand All@@ -246,6 +251,11 @@ def make_node(
f"Unable to create node {op!r}, with inputs={inputs}, "
f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}."
) from e
except ValueError as e:
raise ValueError(
f"Unable to create node {op!r}, with inputs={inputs}, "
f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}."
) from e
for p in protos:
node.attribute.append(p)
if attribute_protos is not None:
Expand Down
9 changes: 8 additions & 1 deletiononnx_array_api/npx/npx_jit_eager.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -510,11 +510,18 @@ def jit_call(self, *values, **kwargs):
from ..plotting.text_plot import onnx_simple_text_plot

text = onnx_simple_text_plot(self.onxs[key])

def catch_len(x):
try:
return len(x)
except TypeError:
return 0

raise RuntimeError(
f"Unable to run function for key={key!r}, "
f"types={[type(x) for x in values]}, "
f"dtypes={[getattr(x, 'dtype', type(x)) for x in values]}, "
f"shapes={[getattr(x, 'shape',len(x)) for x in values]}, "
f"shapes={[getattr(x, 'shape',catch_len(x)) for x in values]}, "
f"kwargs={kwargs}, "
f"self.input_to_kwargs_={self.input_to_kwargs_}, "
f"f={self.f} from module {self.f.__module__!r} "
Expand Down
4 changes: 0 additions & 4 deletionsonnx_array_api/reference/evaluator.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -7,10 +7,6 @@
from .ops.op_cast_like import CastLike_15, CastLike_19
from .ops.op_constant_of_shape import ConstantOfShape

import onnx

print(onnx.__file__)


logger = getLogger("onnx-array-api-eval")

Expand Down

[8]ページ先頭

©2009-2025 Movatter.jp