Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Implements ArrayAPI#17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
sdpython merged 27 commits intomainfromaapi2
Jun 10, 2023
Merged
Show file tree
Hide file tree
Changes from1 commit
Commits
Show all changes
27 commits
Select commitHold shift + click to select a range
945749c
rename ArrayApi into BaseArrayApi
xadupreJun 5, 2023
38c3c0e
Implements ArrayAPI
xadupreJun 5, 2023
178185f
Merge branch 'main' of https://github.com/sdpython/onnx-array-api int…
xadupreJun 5, 2023
f0d4eba
documentation
xadupreJun 5, 2023
0ffa4b4
ci
xadupreJun 5, 2023
293c570
fix ci
xadupreJun 5, 2023
caf2324
xi
xadupreJun 5, 2023
768eb85
ci
xadupreJun 5, 2023
692774e
ci
xadupreJun 5, 2023
5c7dae1
ci
xadupreJun 5, 2023
1a1cd35
ci
xadupreJun 5, 2023
da3a7c8
many changes to follow the Array API
xadupreJun 5, 2023
c600890
more changes
xadupreJun 5, 2023
535cc4a
fix unit test
xadupreJun 5, 2023
5a07911
fix ci
xadupreJun 5, 2023
3481d27
ci
xadupreJun 5, 2023
1990fe8
api
xadupreJun 5, 2023
e0ca8c4
ci
xadupreJun 5, 2023
1e218b5
improvments
xadupreJun 5, 2023
e344011
refactorign
xadupreJun 6, 2023
178e3e9
fix asarray
xadupreJun 6, 2023
550d0dc
new udpates
sdpythonJun 8, 2023
0169b85
fix two bugs
xadupreJun 8, 2023
cb11e55
Add one unit test for empty input
xadupreJun 9, 2023
caa99a7
fix all when shape is empty and has one dimension
xadupreJun 10, 2023
8945d16
fix missing return
xadupreJun 10, 2023
7a979a2
remove the full tests
xadupreJun 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
PrevPrevious commit
NextNext commit
fix asarray
  • Loading branch information
@xadupre
xadupre committedJun 6, 2023
commit178e3e933dbfc64cc6b4a18cd087ad6d1f06d59a
4 changes: 4 additions & 0 deletionsazure-pipelines.yml
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -127,9 +127,13 @@ jobs:
cd array-api-tests
displayName: 'Set API'
- script: |
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
cd array-api-tests
python -m pytest -xv array_api_tests/test_creation_functions.py::test_zeros
displayName: "test_creation_functions.py::test_zeros"
- script: |
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
cd array-api-tests
python -m pytest -x array_api_tests
displayName: "all tests"

Expand Down
50 changes: 50 additions & 0 deletionsonnx_array_api/array_api/_onnx_common.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
from typing import Any, Optional
import numpy as np
from ..npx.npx_types import DType
from ..npx.npx_array_api import BaseArrayApi
from ..npx.npx_functions import (
copy as copy_inline,
)


def template_asarray(
TEagerTensor: type,
a: Any,
dtype: Optional[DType] = None,
order: Optional[str] = None,
like: Any = None,
copy: bool = False,
) -> Any:
"""
Converts anything into an array.
"""
if order not in ("C", None):
raise NotImplementedError(f"asarray is not implemented for order={order!r}.")
if like is not None:
raise NotImplementedError(
f"asarray is not implemented for like != None (type={type(like)})."
)
if isinstance(a, BaseArrayApi):
if copy:
if dtype is None:
return copy_inline(a)
return copy_inline(a).astype(dtype=dtype)
if dtype is None:
return a
return a.astype(dtype=dtype)

if isinstance(a, int):
v = TEagerTensor(np.array(a, dtype=np.int64))
elif isinstance(a, float):
v = TEagerTensor(np.array(a, dtype=np.float32))
elif isinstance(a, bool):
v = TEagerTensor(np.array(a, dtype=np.bool_))
elif isinstance(a, str):
v = TEagerTensor(np.array(a, dtype=np.str_))
else:
raise RuntimeError(f"Unexpected type {type(a)} for the first input.")
if dtype is not None:
vt = v.astype(dtype=dtype)
else:
vt = v
return vt
36 changes: 4 additions & 32 deletionsonnx_array_api/array_api/onnx_numpy.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,13 +4,11 @@
from typing import Any, Optional
import numpy as np
from onnx import TensorProto
from ..npx.npx_array_api import BaseArrayApi
from ..npx.npx_functions import (
all,
abs,
absolute,
astype,
copy as copy_inline,
equal,
isdtype,
reshape,
Expand All@@ -19,6 +17,7 @@
from ..npx.npx_functions import zeros as generic_zeros
from ..npx.npx_numpy_tensors import EagerNumpyTensor
from ..npx.npx_types import DType, ElemType, TensorType, OptParType
from ._onnx_common import template_asarray
from . import _finalize_array_api

__all__ = [
Expand All@@ -45,36 +44,9 @@ def asarray(
"""
Converts anything into an array.
"""
if order not in ("C", None):
raise NotImplementedError(f"asarray is not implemented for order={order!r}.")
if like is not None:
raise NotImplementedError(
f"asarray is not implemented for like != None (type={type(like)})."
)
if isinstance(a, BaseArrayApi):
if copy:
if dtype is None:
return copy_inline(a)
return copy_inline(a).astype(dtype=dtype)
if dtype is None:
return a
return a.astype(dtype=dtype)

if isinstance(a, int):
v = EagerNumpyTensor(np.array(a, dtype=np.int64))
elif isinstance(a, float):
v = EagerNumpyTensor(np.array(a, dtype=np.float32))
elif isinstance(a, bool):
v = EagerNumpyTensor(np.array(a, dtype=np.bool_))
elif isinstance(a, str):
v = EagerNumpyTensor(np.array(a, dtype=np.str_))
else:
raise RuntimeError(f"Unexpected type {type(a)} for the first input.")
if dtype is not None:
vt = v.astype(dtype=dtype)
else:
vt = v
return vt
return template_asarray(
EagerNumpyTensor, a, dtype=dtype, order=order, like=like, copy=copy
)


def zeros(
Expand Down
20 changes: 20 additions & 0 deletionsonnx_array_api/array_api/onnx_ort.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
"""
Array API valid for an :class:`EagerOrtTensor`.
"""
from typing import Optional, Any
from ..ort.ort_tensors import EagerOrtTensor
from ..npx.npx_types import DType
from ..npx.npx_functions import (
all,
abs,
Expand All@@ -11,12 +14,14 @@
reshape,
take,
)
from ._onnx_common import template_asarray
from . import _finalize_array_api

__all__ = [
"all",
"abs",
"absolute",
"asarray",
"astype",
"equal",
"isdtype",
Expand All@@ -25,6 +30,21 @@
]


def asarray(
a: Any,
dtype: Optional[DType] = None,
order: Optional[str] = None,
like: Any = None,
copy: bool = False,
) -> EagerOrtTensor:
"""
Converts anything into an array.
"""
return template_asarray(
EagerOrtTensor, a, dtype=dtype, order=order, like=like, copy=copy
)


def _finalize():
from . import onnx_ort

Expand Down

[8]ページ先頭

©2009-2025 Movatter.jp