Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit32ad385

Browse files
authored
Fixes Array API with onnxruntime (#3)
* Check Array API with onnxruntime* better error message* improvment* disable one test for older version of sklearn* add one more pipeline* fix pipeline* fix array api* remove unnecessary code* disable one test one the current scikit-learn version
1 parent062b6c1 commit32ad385

18 files changed

+265
-68
lines changed

‎CHANGELOGS.rst‎

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Change Logs
2+
===========
3+
4+
0.2.0
5+
+++++
6+
7+
*:pr:`3`: fixes Array API with onnxruntime

‎_doc/conf.py‎

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@
5757
}
5858

5959
epkg_dictionary= {
60+
"Array API":"https://data-apis.org/array-api/",
61+
"ArrayAPI": (
62+
"https://data-apis.org/array-api/",
63+
("2022.12/API_specification/generated/array_api.{0}.html",1),
64+
),
6065
"DOT":"https://graphviz.org/doc/info/lang.html",
6166
"JIT":"https://en.wikipedia.org/wiki/Just-in-time_compilation",
6267
"onnx":"https://onnx.ai/onnx/",
@@ -65,7 +70,7 @@
6570
"numpy":"https://numpy.org/",
6671
"numba":"https://numba.pydata.org/",
6772
"onnx-array-api": (
68-
"http://www.xavierdupre.fr/app/""onnx-array-api/helpsphinx/index.html"
73+
"http://www.xavierdupre.fr/app/onnx-array-api/helpsphinx/index.html"
6974
),
7075
"pyinstrument":"https://github.com/joerick/pyinstrument",
7176
"python":"https://www.python.org/",

‎_doc/index.rst‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ well as to execute it.
3434
tutorial/index
3535
api/index
3636
auto_examples/index
37+
../CHANGELOGS
3738

3839
Sources available on
3940
`github/onnx-array-api<https://github.com/sdpython/onnx-array-api>`_,

‎_unittests/ut_npx/test_sklearn_array_api.py‎

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
importunittest
22
importnumpyasnp
3+
frompackaging.versionimportVersion
34
fromonnx.defsimportonnx_opset_version
4-
fromsklearnimportconfig_context
5+
fromsklearnimportconfig_context,__version__assklearn_version
56
fromsklearn.discriminant_analysisimportLinearDiscriminantAnalysis
67
fromonnx_array_api.ext_test_caseimportExtTestCase
78
fromonnx_array_api.npx.npx_numpy_tensorsimportEagerNumpyTensor
@@ -10,23 +11,15 @@
1011
DEFAULT_OPSET=onnx_opset_version()
1112

1213

13-
deftake(self,X,indices,*,axis):
14-
# Overwritting method take as it is using iterators.
15-
# When array_api supports `take` we can use this directly
16-
# https://github.com/data-apis/array-api/issues/177
17-
X_np=self._namespace.take(X,indices,axis=axis)
18-
returnself._namespace.asarray(X_np)
19-
20-
2114
classTestSklearnArrayAPI(ExtTestCase):
15+
@unittest.skipIf(
16+
Version(sklearn_version)<=Version("1.2.2"),
17+
reason="reshape ArrayAPI not followed",
18+
)
2219
deftest_sklearn_array_api_linear_discriminant(self):
23-
fromsklearn.utils._array_apiimport_ArrayAPIWrapper
24-
25-
_ArrayAPIWrapper.take=take
2620
X=np.array([[-1,-1], [-2,-1], [-3,-2], [1,1], [2,1], [3,2]])
2721
y=np.array([1,1,1,2,2,2])
2822
ana=LinearDiscriminantAnalysis()
29-
ana=LinearDiscriminantAnalysis()
3023
ana.fit(X,y)
3124
expected=ana.predict(X)
3225

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
importunittest
2+
importnumpyasnp
3+
frompackaging.versionimportVersion
4+
fromonnx.defsimportonnx_opset_version
5+
fromsklearnimportconfig_context,__version__assklearn_version
6+
fromsklearn.discriminant_analysisimportLinearDiscriminantAnalysis
7+
fromonnx_array_api.ext_test_caseimportExtTestCase
8+
fromonnx_array_api.ort.ort_tensorsimportEagerOrtTensor,OrtTensor
9+
10+
11+
DEFAULT_OPSET=onnx_opset_version()
12+
13+
14+
classTestSklearnArrayAPIOrt(ExtTestCase):
15+
@unittest.skipIf(
16+
Version(sklearn_version)<=Version("1.2.2"),
17+
reason="reshape ArrayAPI not followed",
18+
)
19+
deftest_sklearn_array_api_linear_discriminant(self):
20+
X=np.array([[-1,-1], [-2,-1], [-3,-2], [1,1], [2,1], [3,2]])
21+
y=np.array([1,1,1,2,2,2])
22+
ana=LinearDiscriminantAnalysis()
23+
ana.fit(X,y)
24+
expected=ana.predict(X)
25+
26+
new_x=EagerOrtTensor(OrtTensor.from_array(X))
27+
self.assertEqual(new_x.device_name,"Cpu")
28+
self.assertStartsWith(
29+
"EagerOrtTensor(OrtTensor.from_array(array([[",repr(new_x)
30+
)
31+
withconfig_context(array_api_dispatch=True):
32+
got=ana.predict(new_x)
33+
self.assertEqualArray(expected,got.numpy())
34+
35+
36+
if__name__=="__main__":
37+
# import logging
38+
# logging.basicConfig(level=logging.DEBUG)
39+
unittest.main(verbosity=2)

‎azure-pipelines.yml‎

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,53 @@ jobs:
4343
artifactName:'wheel-linux-wheel-$(python.version)'
4444
targetPath:'dist'
4545

46+
-job:'TestLinuxNightly'
47+
pool:
48+
vmImage:'ubuntu-latest'
49+
strategy:
50+
matrix:
51+
Python310-Linux:
52+
python.version:'3.11'
53+
maxParallel:3
54+
55+
steps:
56+
-task:UsePythonVersion@0
57+
inputs:
58+
versionSpec:'$(python.version)'
59+
architecture:'x64'
60+
-script:sudo apt-get update
61+
displayName:'AptGet Update'
62+
-script:sudo apt-get install -y pandoc
63+
displayName:'Install Pandoc'
64+
-script:sudo apt-get install -y inkscape
65+
displayName:'Install Inkscape'
66+
-script:sudo apt-get install -y graphviz
67+
displayName:'Install Graphviz'
68+
-script:python -m pip install --upgrade pip setuptools wheel
69+
displayName:'Install tools'
70+
-script:pip install -r requirements.txt
71+
displayName:'Install Requirements'
72+
-script:pip install -r requirements-dev.txt
73+
displayName:'Install Requirements dev'
74+
-script:pip uninstall -y scikit-learn
75+
displayName:'Uninstall scikit-learn'
76+
-script:pip install --pre --extra-index https://pypi.anaconda.org/scipy-wheels-nightly/simple scikit-learn
77+
displayName:'Install scikit-learn nightly'
78+
-script:pip install onnxmltools --no-deps
79+
displayName:'Install onnxmltools'
80+
-script:|
81+
ruff .
82+
displayName: 'Ruff'
83+
-script:|
84+
rstcheck -r ./_doc ./onnx_array_api
85+
displayName: 'rstcheck'
86+
-script:|
87+
black --diff .
88+
displayName: 'Black'
89+
-script:|
90+
python -m pytest -v
91+
displayName: 'Runs Unit Tests'
92+
4693
-job:'TestLinux'
4794
pool:
4895
vmImage:'ubuntu-latest'

‎onnx_array_api/npx/npx_array_api.py‎

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,34 @@
1-
fromtypingimportAny
1+
fromtypingimportAny,Optional
22

33
importnumpyasnp
44

55
from .npx_typesimportOptParType,ParType,TupleType
66

77

8+
classArrayApiError(RuntimeError):
9+
"""
10+
Raised when a function is not supported by the :epkg:`Array API`.
11+
"""
12+
13+
pass
14+
15+
816
classArrayApi:
917
"""
1018
List of supported method by a tensor.
1119
"""
1220

13-
def__array_namespace__(self):
21+
def__array_namespace__(self,api_version:Optional[str]=None):
1422
"""
1523
Returns the module holding all the available functions.
1624
"""
17-
fromonnx_array_api.npximportnpx_functions
25+
ifapi_versionisNoneorapi_version=="2022.12":
26+
fromonnx_array_api.npximportnpx_functions
1827

19-
returnnpx_functions
28+
returnnpx_functions
29+
raiseValueError(
30+
f"Unable to return an implementation for api_version={api_version!r}."
31+
)
2032

2133
defgeneric_method(self,method_name,*args:Any,**kwargs:Any)->Any:
2234
raiseNotImplementedError(

‎onnx_array_api/npx/npx_core_api.py‎

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,10 @@ def npxapi_inline(fn):
252252
to call.
253253
"""
254254
return_xapi(fn,inline=True)
255+
256+
257+
defnpxapi_no_inline(fn):
258+
"""
259+
Functions decorated with this decorator are not converted into ONNX.
260+
"""
261+
returnfn

‎onnx_array_api/npx/npx_functions.py‎

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
fromtypingimportAny,Optional,Tuple,Union
22

3+
importarray_api_compat.numpyasnp_array_api
34
importnumpyasnp
45
fromonnximportFunctionProto,ModelProto,NodeProto,TensorProto
56
fromonnx.helperimportnp_dtype_to_tensor_dtype
67
fromonnx.numpy_helperimportfrom_array
78

89
from .npx_constantsimportFUNCTION_DOMAIN
9-
from .npx_core_apiimportcst,make_tuple,npxapi_inline,var
10+
from .npx_core_apiimportcst,make_tuple,npxapi_inline,npxapi_no_inline,var
1011
from .npx_tensorsimportArrayApi
1112
from .npx_typesimport (
13+
DType,
1214
ElemType,
1315
OptParType,
1416
ParType,
@@ -397,6 +399,17 @@ def identity(n: ParType[int], dtype=None) -> TensorType[ElemType.numerics, "T"]:
397399
returnv
398400

399401

402+
@npxapi_no_inline
403+
defisdtype(
404+
dtype:DType,kind:Union[DType,str,Tuple[Union[DType,str], ...]]
405+
)->bool:
406+
"""
407+
See :epkg:`ArrayAPI:isdtype`.
408+
This function is not converted into an onnx graph.
409+
"""
410+
returnnp_array_api.isdtype(dtype,kind)
411+
412+
400413
@npxapi_inline
401414
defisnan(x:TensorType[ElemType.numerics,"T"])->TensorType[ElemType.bool_,"T"]:
402415
"See :func:`numpy.isnan`."
@@ -460,9 +473,23 @@ def relu(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics,
460473

461474
@npxapi_inline
462475
defreshape(
463-
x:TensorType[ElemType.numerics,"T"],shape:TensorType[ElemType.int64,"I"]
476+
x:TensorType[ElemType.numerics,"T"],
477+
shape:TensorType[ElemType.int64,"I", (None,)],
464478
)->TensorType[ElemType.numerics,"T"]:
465-
"See :func:`numpy.reshape`."
479+
"""
480+
See :func:`numpy.reshape`.
481+
482+
.. warning::
483+
484+
Numpy definition is tricky because onnxruntime does not handle well
485+
dimensions with an undefined number of dimensions.
486+
However the array API defines a more stricly signature for
487+
`reshape <https://data-apis.org/array-api/2022.12/
488+
API_specification/generated/array_api.reshape.html>`_.
489+
:epkg:`scikit-learn` updated its code to follow the Array API in
490+
`PR 26030 ENH Forces shape to be tuple when using Array API's reshape
491+
<https://github.com/scikit-learn/scikit-learn/pull/26030>`_.
492+
"""
466493
ifisinstance(shape,int):
467494
shape=cst(np.array([shape],dtype=np.int64))
468495
shape_reshaped=var(shape,cst(np.array([-1],dtype=np.int64)),op="Reshape")

‎onnx_array_api/npx/npx_graph_builder.py‎

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,23 @@ def to_onnx(
798798
node_inputs.append(input_name)
799799
continue
800800

801+
ifisinstance(i,tuple)andall(map(lambdax:isinstance(x,int),i)):
802+
ai=np.array(list(i),dtype=np.int64)
803+
c=Cst(ai)
804+
input_name=self._unique(var._prefix)
805+
self._id_vars[id(i),index]=input_name
806+
self._id_vars[id(c),index]=input_name
807+
self.make_node(
808+
"Constant",
809+
[],
810+
[input_name],
811+
value=from_array(ai),
812+
opset=self.target_opsets[""],
813+
)
814+
self.onnx_names_[input_name]=c
815+
node_inputs.append(input_name)
816+
continue
817+
801818
raiseNotImplementedError(
802819
f"Unexpected type{type(i)} for node={domop}."
803820
)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp