Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Extend ExtendedReferenceEvaluator#75

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
sdpython merged 10 commits intomainfromref2
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
Show all changes
10 commits
Select commitHold shift + click to select a range
6eb6adf
update requirements
xadupreFeb 2, 2024
4f0a994
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 2, 2024
c7bb055
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 2, 2024
a3d4ccf
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 6, 2024
7ed1385
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 6, 2024
bab2a6b
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 7, 2024
014404b
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 7, 2024
00e2a1c
Merge branch 'main' of https://github.com/sdpython/onnx-array-api
xadupreFeb 14, 2024
228ff67
add more operator to the reference evaluator
xadupreFeb 14, 2024
5f37a59
extend unit test copverage
xadupreFeb 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletionsCHANGELOGS.rst
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,6 +4,7 @@ Change Logs
0.2.0
+++++

* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
* :pr:`71`: adds tools to compare two onnx graphs
* :pr:`61`: adds function to plot onnx model as graphs
* :pr:`60`: supports translation of local functions
Expand Down
82 changes: 82 additions & 0 deletions_unittests/ut_reference/test_reference_ops.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -59,6 +59,88 @@ def test_fused_matmul11(self):
got = ref.run(None, {"X": a, "Y": a})
self.assertEqualArray(a.T @ a.T, got[0])

def test_memcpy(self):
model = make_model(
make_graph(
[
make_node("MemcpyToHost", ["X"], ["Z"]),
make_node("MemcpyFromHost", ["X"], ["Z"]),
],
"name",
[make_tensor_value_info("X", TensorProto.FLOAT, None)],
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
),
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
ir_version=9,
)
a = np.arange(4).reshape(-1, 2).astype(np.float32)
ref = ExtendedReferenceEvaluator(model)
got = ref.run(None, {"X": a})
self.assertEqualArray(a, got[0])

def test_quick_gelu(self):
from onnxruntime import InferenceSession

for alpha in [0.0, 2.0]:
model = make_model(
make_graph(
[
make_node(
"QuickGelu",
["X"],
["Z"],
domain="com.microsoft",
alpha=alpha,
)
],
"name",
[make_tensor_value_info("X", TensorProto.FLOAT, None)],
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
),
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
ir_version=9,
)
sess = InferenceSession(
model.SerializeToString(), providers=["CPUExecutionProvider"]
)
a = np.arange(4).reshape(-1, 2).astype(np.float32)
expected = sess.run(None, {"X": a})
ref = ExtendedReferenceEvaluator(model)
got = ref.run(None, {"X": a})
self.assertEqualArray(expected[0], got[0])

def test_scatter_elements(self):
model = make_model(
make_graph(
[
make_node(
"ScatterElements",
["data", "indices", "updates"],
["Z"],
axis=3,
reduction="add",
)
],
"name",
[
make_tensor_value_info("data", TensorProto.FLOAT, None),
make_tensor_value_info("indices", TensorProto.INT64, None),
make_tensor_value_info("updates", TensorProto.FLOAT, None),
],
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
),
opset_imports=[make_opsetid("", 18)],
)
data = np.zeros(2**4, dtype=np.float32).reshape((2, 2, 2, 2))
indices = np.array([[[[0]]]], dtype=np.int64)
updates = np.array([[[[1]]]], dtype=np.float32)
y = np.array(
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32
).reshape((2, 2, 2, 2))
ref = ExtendedReferenceEvaluator(model)
got = ref.run(None, {"data": data, "indices": indices, "updates": updates})
self.assertEqualArray(y, got[0])


if __name__ == "__main__":
unittest.main(verbosity=2)
7 changes: 7 additions & 0 deletionsonnx_array_api/reference/evaluator.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -8,6 +8,9 @@
from .ops.op_concat import Concat
from .ops.op_constant_of_shape import ConstantOfShape
from .ops.op_fused_matmul import FusedMatMul
from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
from .ops.op_quick_gelu import QuickGelu
from .ops.op_scatter_elements import ScatterElements


logger = getLogger("onnx-array-api-eval")
Expand All@@ -34,6 +37,10 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
CastLike_19,
ConstantOfShape,
FusedMatMul,
MemcpyFromHost,
MemcpyToHost,
QuickGelu,
ScatterElements,
]

@staticmethod
Expand Down
11 changes: 11 additions & 0 deletionsonnx_array_api/reference/ops/op_memcpy_host.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
from onnx.reference.op_run import OpRun


class MemcpyFromHost(OpRun):
def _run(self, x):
return (x,)


class MemcpyToHost(OpRun):
def _run(self, x):
return (x,)
23 changes: 23 additions & 0 deletionsonnx_array_api/reference/ops/op_quick_gelu.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
import numpy as np
from onnx.reference.op_run import OpRun


def sigmoid(x): # type: ignore
if x > 0:
return 1 / (1 + np.exp(-x))
return np.exp(x) / (1 + np.exp(x))


class QuickGelu(OpRun):
op_domain = "com.microsoft"

def __init__(self, onnx_node, run_params): # type: ignore
OpRun.__init__(self, onnx_node, run_params)
self.vf = np.vectorize(sigmoid)

def _run(self, X, alpha=1.0):
if len(X.shape) == 0:
return ((X * sigmoid(X * alpha)).astype(X.dtype),)
if X.size == 0:
return (X,)
return ((X * self.vf(X * alpha)).astype(X.dtype),)
98 changes: 98 additions & 0 deletionsonnx_array_api/reference/ops/op_scatter_elements.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
import numpy as np

from onnx.reference.op_run import OpRun


def scatter_elements(data, indices, updates, axis=0, reduction=None): # type: ignore
if reduction == "add":

def f(x, y):
return x + y

elif reduction == "min":

def f(x, y):
return min(x, y)

elif reduction == "max":

def f(x, y):
return max(x, y)

else:

def f(x, y):
return y

if axis < 0:
axis = data.ndim + axis

if len(data.shape) == 1 and axis == 0:
scattered = np.copy(data)
for pos, up in zip(indices, updates):
scattered[pos] = f(scattered[pos], up)
return scattered

if len(indices.shape) == 2:
scattered = np.copy(data)
if axis == 0:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
scattered[indices[i, j], j] = f(
scattered[indices[i, j], j], updates[i, j]
)
else:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
scattered[i, indices[i, j]] = f(
scattered[i, indices[i, j]], updates[i, j]
)
return scattered

if len(indices.shape) == 3:
scattered = np.copy(data)
if axis == 0:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in range(indices.shape[2]):
scattered[indices[i, j, k], j, k] = f(
scattered[indices[i, j, k], j, k], updates[i, j, k]
)
elif axis == 1:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in range(indices.shape[2]):
scattered[i, indices[i, j, k], k] = f(
scattered[i, indices[i, j, k], k], updates[i, j, k]
)
elif axis == 2:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in range(indices.shape[2]):
scattered[i, j, indices[i, j, k]] = f(
scattered[i, j, indices[i, j, k]], updates[i, j, k]
)
return scattered

if len(indices.shape) == 4:
scattered = np.copy(data)
if axis == 3:
for a in range(indices.shape[0]):
for i in range(indices.shape[1]):
for j in range(indices.shape[2]):
for k in range(indices.shape[3]):
scattered[a, i, j, indices[a, i, j, k]] = f(
scattered[a, i, j, indices[a, i, j, k]],
updates[a, i, j, k],
)
return scattered

raise RuntimeError(
f"Not implemented for indices.shape={indices.shape} and axis={axis}"
)


class ScatterElements(OpRun):
def _run(self, data, indices, updates, axis=None, reduction=None): # type: ignore
res = scatter_elements(data, indices, updates, axis=axis, reduction=reduction)
return (res,)

[8]ページ先頭

©2009-2025 Movatter.jp