Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

First draft to export to GraphBuilder#83

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
sdpython merged 7 commits intomainfromexpo
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletionCHANGELOGS.rst
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
Change Logs
===========

0.2.0
0.3.0
+++++

* :pr:`79`: first draft to export to GraphBuilder
* :pr:`77`: supports ConcatOfShape and Slice with the light API

0.2.0
+++++

* :pr:`76`, :pr:`79`: add a mode to compare models without execution
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
* :pr:`71`: adds tools to compare two onnx graphs
Expand Down
1 change: 0 additions & 1 deletion_unittests/ut_translate_api/test_translate.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -221,5 +221,4 @@ def test_aionnxml(self):


if __name__ == "__main__":
TestTranslate().test_export_if()
unittest.main(verbosity=2)
122 changes: 122 additions & 0 deletions_unittests/ut_translate_api/test_translate_builder.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
import unittest
from textwrap import dedent
import numpy as np
from onnx import ModelProto, TensorProto
from onnx.checker import check_model
from onnx.defs import onnx_opset_version
from onnx.reference import ReferenceEvaluator
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.light_api import start
from onnx_array_api.graph_api import GraphBuilder
from onnx_array_api.translate_api import translate


OPSET_API = min(19, onnx_opset_version() - 1)


class TestTranslateBuilder(ExtTestCase):
def setUp(self):
self.maxDiff = None

def test_exp(self):
onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
self.assertIsInstance(onx, ModelProto)
self.assertIn("Exp", str(onx))
ref = ReferenceEvaluator(onx)
a = np.arange(10).astype(np.float32)
got = ref.run(None, {"X": a})[0]
self.assertEqualArray(np.exp(a), got)

code = translate(onx, api="builder")
expected = dedent(
"""
def light_api(
op: "GraphBuilder",
X: "FLOAT[]",
):
Y = op.Exp(X)
op.Identity(Y, outputs=["Y"])
return Y

g = GraphBuilder({'': 19})
g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, "X")
g.make_tensor_output("Y", TensorProto.FLOAT, ())
model = g.to_onnx()
"""
).strip("\n")
self.assertEqual(expected, code.strip("\n"))

def light_api(
op: "GraphBuilder",
X: "FLOAT[]", # noqa: F722
):
Y = op.Exp(X)
op.Identity(Y, outputs=["Y"])
return Y

g2 = GraphBuilder({"": 19})
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
light_api(g2.op, "X")
g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
onx2 = g2.to_onnx()

ref = ReferenceEvaluator(onx2)
a = np.arange(10).astype(np.float32)
got = ref.run(None, {"X": a})[0]
self.assertEqualArray(np.exp(a), got)

def test_zdoc(self):
onx = (
start(opset=19)
.vin("X")
.reshape((-1, 1))
.Transpose(perm=[1, 0])
.rename("Y")
.vout()
.to_onnx()
)
code = translate(onx, api="builder")
expected = dedent(
"""
def light_api(
op: "GraphBuilder",
X: "FLOAT[]",
):
r = np.array([-1, 1], dtype=np.int64)
r0_0 = op.Reshape(X, r)
Y = op.Transpose(r0_0, perm=[1, 0])
op.Identity(Y, outputs=["Y"])
return Y

g = GraphBuilder({'': 19})
g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, "X")
g.make_tensor_output("Y", TensorProto.FLOAT, ())
model = g.to_onnx()
"""
).strip("\n")
self.maxDiff = None
self.assertEqual(expected, code.strip("\n"))

def light_api(
op: "GraphBuilder",
X: "FLOAT[]", # noqa: F722
):
r = np.array([-1, 1], dtype=np.int64)
r0_0 = op.Reshape(X, r)
Y = op.Transpose(r0_0, perm=[1, 0])
op.Identity(Y, outputs=["Y"])
return Y

g = GraphBuilder({"": 21})
X = g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, X)
g.make_tensor_output("Y", TensorProto.FLOAT, ())
model = g.to_onnx()
self.assertNotEmpty(model)
check_model(model)


if __name__ == "__main__":
unittest.main(verbosity=2)
12 changes: 12 additions & 0 deletionsonnx_array_api/graph_api/graph_builder.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -119,6 +119,18 @@ def __getattr__(self, name):
except AttributeError as e:
raise AttributeError(f"Unable to access attribute {name!r}.") from e

def Initializer(
self, init: Union[TensorProto, np.ndarray], name: Optional[str] = None
) -> str:
"""
Creates an initializer.

:param init: value
:param name: name if value is not a TensorProto
:return: its name
"""
return self.builder.make_initializer(init, name=name, exists=True)

def make_node(
self,
op_type: str,
Expand Down
30 changes: 28 additions & 2 deletionsonnx_array_api/translate_api/__init__.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
from onnx import ModelProto
from .translate import Translater
from .inner_emitter import InnerEmitter
from .builder_emitter import BuilderEmitter


def translate(proto: ModelProto, single_line: bool = False, api: str = "light") -> str:
Expand All@@ -14,7 +15,8 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
default is `"light"` and this is handle by class
:class:`onnx_array_api.translate_api.light_emitter.LightEmitter`,
another value is `"onnx"` which is the inner API implemented
in onnx package.
in onnx package, `"builder"` follows the syntax for the
class :class:`onnx_array_api.graph_api.GraphBuilder`
:return: code

.. runpython::
Expand All@@ -35,7 +37,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
code = translate(onx)
print(code)

The inner API from onnxpackahe is also available.
The inner API from onnxpackage is also available.

.. runpython::
:showcode:
Expand All@@ -54,11 +56,35 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
)
code = translate(onx, api="onnx")
print(code)

The :class:`GraphBuilder
<onnx_array_api.graph_api.GraphBuilder>` API returns this:

.. runpython::
:showcode:

from onnx_array_api.light_api import start
from onnx_array_api.translate_api import translate

onx = (
start()
.vin("X")
.reshape((-1, 1))
.Transpose(perm=[1, 0])
.rename("Y")
.vout()
.to_onnx()
)
code = translate(onx, api="builder")
print(code)
"""
if api == "light":
tr = Translater(proto)
return tr.export(single_line=single_line, as_str=True)
if api == "onnx":
tr = Translater(proto, emitter=InnerEmitter())
return tr.export(as_str=True)
if api == "builder":
tr = Translater(proto, emitter=BuilderEmitter())
return tr.export(as_str=True)
raise ValueError(f"Unexpected value {api!r} for api.")
28 changes: 28 additions & 0 deletionsonnx_array_api/translate_api/base_emitter.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -21,6 +21,10 @@ class EventType(IntEnum):
FUNCTION_OUTPUT = 12
FUNCTION_ATTRIBUTES = 13
TO_ONNX_FUNCTION = 14
BEGIN_SIGNATURE = 15
END_SIGNATURE = 16
BEGIN_RETURN = 17
END_RETURN = 18

@classmethod
def to_str(cls, self) -> str:
Expand DownExpand Up@@ -84,6 +88,18 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
if event == EventType.FUNCTION_ATTRIBUTES:
return self._emit_function_attributes(**kwargs)

if event == EventType.BEGIN_SIGNATURE:
return self._emit_begin_signature(**kwargs)

if event == EventType.END_SIGNATURE:
return self._emit_end_signature(**kwargs)

if event == EventType.BEGIN_RETURN:
return self._emit_begin_return(**kwargs)

if event == EventType.END_RETURN:
return self._emit_end_return(**kwargs)

raise ValueError(f"Unexpected event {EventType.to_str(event)}.")

def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
Expand DownExpand Up@@ -222,3 +238,15 @@ def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
raise NotImplementedError(
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
)

def _emit_begin_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
return []

def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
return []

def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
return []

def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
return []
Loading

[8]ページ先頭

©2009-2025 Movatter.jp