Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Add ExtendedReferenceEvaluator to test scenario outside onnx specifications#24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
sdpython merged 4 commits intomainfromexe
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletionsCHANGELOGS.rst
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,6 +4,7 @@ Change Logs
0.2.0
+++++

* :pr:`24`: add ExtendedReferenceEvaluator to support scenario for the Array API onnx does not support
* :pr:`22`: support OrtValue in function :func:`ort_profile`
* :pr:`17`: implements ArrayAPI
* :pr:`3`: fixes Array API with onnxruntime and scikit-learn
1 change: 1 addition & 0 deletions_doc/api/index.rst
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -15,4 +15,5 @@ API
onnx_tools
ort
plotting
reference
tools
7 changes: 7 additions & 0 deletions_doc/api/reference.rst
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
reference
=========

ExtendedReferenceEvaluator
++++++++++++++++++++++++++

.. autoclass:: onnx_array_api.reference.ExtendedReferenceEvaluator
2 changes: 0 additions & 2 deletions_unittests/onnx-numpy-skips.txt
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -9,6 +9,4 @@ array_api_tests/test_creation_functions.py::test_eye
array_api_tests/test_creation_functions.py::test_full_like
array_api_tests/test_creation_functions.py::test_linspace
array_api_tests/test_creation_functions.py::test_meshgrid
# Issue with CastLike and bfloat16 on onnx <= 1.15.0
# array_api_tests/test_creation_functions.py::test_ones_like
array_api_tests/test_creation_functions.py::test_zeros_like
7 changes: 1 addition & 6 deletions_unittests/ut_array_api/test_onnx_numpy.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
import sys
import unittest
from packaging.version import Version
import numpy as np
from onnx import TensorProto, __version__ as onnx_ver
from onnx import TensorProto
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.array_api import onnx_numpy as xp
from onnx_array_api.npx.npx_types import DType
Expand DownExpand Up@@ -99,10 +98,6 @@ def test_arange_int00(self):
expected = expected.astype(np.int64)
self.assertEqualArray(matnp, expected)

@unittest.skipIf(
Version(onnx_ver) < Version("1.15.0"),
reason="Reference implementation of CastLike is bugged.",
)
def test_ones_like_uint16(self):
x = EagerTensor(np.array(0, dtype=np.uint16))
y = np.ones_like(x.numpy())
Expand Down
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
import os
import platform
import unittest
from typing import Any
import numpy
import onnx.backend.base
import onnx.backend.test
import onnx.shape_inference
import onnx.version_converter
from onnx import ModelProto
from onnx.backend.base import Device, DeviceType
from onnx.defs import onnx_opset_version
from onnx_array_api.reference import ExtendedReferenceEvaluator


class ExtendedReferenceEvaluatorBackendRep(onnx.backend.base.BackendRep):
def __init__(self, session):
self._session = session

def run(self, inputs, **kwargs):
if isinstance(inputs, numpy.ndarray):
inputs = [inputs]
if isinstance(inputs, list):
if len(inputs) == len(self._session.input_names):
feeds = dict(zip(self._session.input_names, inputs))
else:
feeds = {}
pos_inputs = 0
for inp, tshape in zip(
self._session.input_names, self._session.input_types
):
shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim)
if shape == inputs[pos_inputs].shape:
feeds[inp] = inputs[pos_inputs]
pos_inputs += 1
if pos_inputs >= len(inputs):
break
elif isinstance(inputs, dict):
feeds = inputs
else:
raise TypeError(f"Unexpected input type {type(inputs)!r}.")
outs = self._session.run(None, feeds)
return outs


class ExtendedReferenceEvaluatorBackend(onnx.backend.base.Backend):
@classmethod
def is_opset_supported(cls, model): # pylint: disable=unused-argument
return True, ""

@classmethod
def supports_device(cls, device: str) -> bool:
d = Device(device)
return d.type == DeviceType.CPU # type: ignore[no-any-return]

@classmethod
def create_inference_session(cls, model):
return ExtendedReferenceEvaluator(model)

@classmethod
def prepare(
cls, model: Any, device: str = "CPU", **kwargs: Any
) -> ExtendedReferenceEvaluatorBackendRep:
# if isinstance(model, ExtendedReferenceEvaluatorBackendRep):
# return model
if isinstance(model, ExtendedReferenceEvaluator):
return ExtendedReferenceEvaluatorBackendRep(model)
if isinstance(model, (str, bytes, ModelProto)):
inf = cls.create_inference_session(model)
return cls.prepare(inf, device, **kwargs)
raise TypeError(f"Unexpected type {type(model)} for model.")

@classmethod
def run_model(cls, model, inputs, device=None, **kwargs):
rep = cls.prepare(model, device, **kwargs)
return rep.run(inputs, **kwargs)

@classmethod
def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
raise NotImplementedError("Unable to run the model node by node.")


backend_test = onnx.backend.test.BackendTest(
ExtendedReferenceEvaluatorBackend, __name__
)

if os.getenv("APPVEYOR"):
backend_test.exclude("(test_vgg19|test_zfnet)")
if platform.architecture()[0] == "32bit":
backend_test.exclude("(test_vgg19|test_zfnet|test_bvlc_alexnet)")
if platform.system() == "Windows":
backend_test.exclude("test_sequence_model")

if onnx_opset_version() < 21:
backend_test.exclude(
"(test_averagepool_2d_dilations"
"|test_if*"
"|test_loop*"
"|test_scan*"
"|test_sequence_map*"
")"
)

if onnx_opset_version() < 19:
backend_test.exclude(
"(test_argm[ai][nx]_default_axis_example"
"|test_argm[ai][nx]_default_axis_random"
"|test_argm[ai][nx]_keepdims_example"
"|test_argm[ai][nx]_keepdims_random"
"|test_argm[ai][nx]_negative_axis_keepdims_example"
"|test_argm[ai][nx]_negative_axis_keepdims_random"
"|test_argm[ai][nx]_no_keepdims_example"
"|test_argm[ai][nx]_no_keepdims_random"
"|test_col2im_pads"
"|test_gru_batchwise"
"|test_gru_defaults"
"|test_gru_seq_length"
"|test_gru_with_initial_bias"
"|test_layer_normalization_2d_axis1_expanded"
"|test_layer_normalization_2d_axis_negative_1_expanded"
"|test_layer_normalization_3d_axis1_epsilon_expanded"
"|test_layer_normalization_3d_axis2_epsilon_expanded"
"|test_layer_normalization_3d_axis_negative_1_epsilon_expanded"
"|test_layer_normalization_3d_axis_negative_2_epsilon_expanded"
"|test_layer_normalization_4d_axis1_expanded"
"|test_layer_normalization_4d_axis2_expanded"
"|test_layer_normalization_4d_axis3_expanded"
"|test_layer_normalization_4d_axis_negative_1_expanded"
"|test_layer_normalization_4d_axis_negative_2_expanded"
"|test_layer_normalization_4d_axis_negative_3_expanded"
"|test_layer_normalization_default_axis_expanded"
"|test_logsoftmax_large_number_expanded"
"|test_lstm_batchwise"
"|test_lstm_defaults"
"|test_lstm_with_initial_bias"
"|test_lstm_with_peepholes"
"|test_mvn"
"|test_mvn_expanded"
"|test_softmax_large_number_expanded"
"|test_operator_reduced_mean"
"|test_operator_reduced_mean_keepdim)"
)

# The following tests are not supported.
backend_test.exclude(
"(test_gradient"
"|test_if_opt"
"|test_loop16_seq_none"
"|test_range_float_type_positive_delta_expanded"
"|test_range_int32_type_negative_delta_expanded"
"|test_scan_sum)"
)

if onnx_opset_version() < 21:
# The following tests are using types not supported by NumPy.
# They could be if method to_array is extended to support custom
# types the same as the reference implementation does
# (see onnx.reference.op_run.to_array_extended).
backend_test.exclude(
"(test_cast_FLOAT_to_BFLOAT16"
"|test_cast_BFLOAT16_to_FLOAT"
"|test_cast_BFLOAT16_to_FLOAT"
"|test_castlike_BFLOAT16_to_FLOAT"
"|test_castlike_FLOAT_to_BFLOAT16"
"|test_castlike_FLOAT_to_BFLOAT16_expanded"
"|test_cast_no_saturate_"
"|_to_FLOAT8"
"|_FLOAT8"
"|test_quantizelinear_e4m3fn"
"|test_quantizelinear_e5m2"
")"
)

# Disable test about float 8
backend_test.exclude(
"(test_castlike_BFLOAT16*"
"|test_cast_BFLOAT16*"
"|test_cast_no_saturate*"
"|test_cast_FLOAT_to_FLOAT8*"
"|test_cast_FLOAT16_to_FLOAT8*"
"|test_cast_FLOAT8_to_*"
"|test_castlike_BFLOAT16*"
"|test_castlike_no_saturate*"
"|test_castlike_FLOAT_to_FLOAT8*"
"|test_castlike_FLOAT16_to_FLOAT8*"
"|test_castlike_FLOAT8_to_*"
"|test_quantizelinear_e*)"
)

# The following tests are too slow with the reference implementation (Conv).
backend_test.exclude(
"(test_bvlc_alexnet"
"|test_densenet121"
"|test_inception_v1"
"|test_inception_v2"
"|test_resnet50"
"|test_shufflenet"
"|test_squeezenet"
"|test_vgg19"
"|test_zfnet512)"
)

# The following tests cannot pass because they consists in generating random number.
backend_test.exclude("(test_bernoulli)")

if onnx_opset_version() < 21:
# The following tests fail due to a bug in the backend test comparison.
backend_test.exclude(
"(test_cast_FLOAT_to_STRING|test_castlike_FLOAT_to_STRING|test_strnorm)"
)

# The following tests fail due to a shape mismatch.
backend_test.exclude(
"(test_center_crop_pad_crop_axes_hwc_expanded|test_lppool_2d_dilations)"
)

# The following tests fail due to a type mismatch.
backend_test.exclude("(test_eyelike_without_dtype)")

# The following tests fail due to discrepancies (small but still higher than 1e-7).
backend_test.exclude("test_adam_multiple") # 1e-2


# import all test cases at global scope to make them visible to python.unittest
globals().update(backend_test.test_cases)

if __name__ == "__main__":
res = unittest.main(verbosity=2, exit=False)
tests_run = res.result.testsRun
errors = len(res.result.errors)
skipped = len(res.result.skipped)
unexpected_successes = len(res.result.unexpectedSuccesses)
expected_failures = len(res.result.expectedFailures)
print("---------------------------------")
print(
f"tests_run={tests_run} errors={errors} skipped={skipped} "
f"unexpected_successes={unexpected_successes} "
f"expected_failures={expected_failures}"
)
8 changes: 4 additions & 4 deletionsonnx_array_api/npx/npx_numpy_tensors.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
from typing import Any, Callable, List, Optional, Tuple
import numpy as np
from onnx import ModelProto, TensorProto
fromonnx.reference importReferenceEvaluator
from..reference importExtendedReferenceEvaluator
from .._helpers import np_dtype_to_tensor_dtype
from .npx_numpy_tensors_ops import ConstantOfShape
from .npx_tensors import EagerTensor, JitTensor
Expand All@@ -11,15 +11,15 @@
class NumpyTensor:
"""
Default backend based on
:func:`onnx.reference.ReferenceEvaluator`.
:func:`onnx_array_api.reference.ExtendedReferenceEvaluator`.

:param input_names: input names
:param onx: onnx model
"""

class Evaluator:
"""
Wraps class :class:`onnx.reference.ReferenceEvaluator`
Wraps class :class:`onnx_array_api.reference.ExtendedReferenceEvaluator`
to have a signature closer to python function.

:param tensor_class: class tensor such as :class:`NumpyTensor`
Expand All@@ -35,7 +35,7 @@ def __init__(
onx: ModelProto,
f: Callable,
):
self.ref =ReferenceEvaluator(onx, new_ops=[ConstantOfShape])
self.ref =ExtendedReferenceEvaluator(onx, new_ops=[ConstantOfShape])
self.input_names = input_names
self.tensor_class = tensor_class
self._f = f
Expand Down
1 change: 1 addition & 0 deletionsonnx_array_api/reference/__init__.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
from .evaluator import ExtendedReferenceEvaluator
Loading

[8]ページ先頭

©2009-2025 Movatter.jp