Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Improves translation to GraphBuilder#95

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
sdpython merged 6 commits intomainfrombugr
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions.github/workflows/check-urls.yml
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -42,6 +42,6 @@ jobs:
print_all: false
timeout: 2
retry_count# : 2
exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/
exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document,https://github.com/onnx/tensorflow-onnx
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://github.com/onnx/tensorflow-onnx
# force_pass : true
5 changes: 5 additions & 0 deletionsCHANGELOGS.rst
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
Change Logs
===========

0.3.1
+++++

* :pr:`95`: improves translation to GraphBuilder

0.3.0
+++++

Expand Down
67 changes: 62 additions & 5 deletions_unittests/ut_translate_api/test_translate_builder.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -8,7 +8,8 @@
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.light_api import start
from onnx_array_api.graph_api import GraphBuilder
from onnx_array_api.translate_api import translate
from onnx_array_api.translate_api import translate, Translater
from onnx_array_api.translate_api.builder_emitter import BuilderEmitter


OPSET_API = min(19, onnx_opset_version() - 1)
Expand All@@ -19,7 +20,7 @@ def setUp(self):
self.maxDiff = None

def test_exp(self):
onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
self.assertIsInstance(onx, ModelProto)
self.assertIn("Exp", str(onx))
ref = ReferenceEvaluator(onx)
Expand All@@ -38,7 +39,7 @@ def light_api(
op.Identity(Y, outputs=["Y"])
return Y

g = GraphBuilder({'': 19})
g = GraphBuilder({'': 19}, ir_version=10)
g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, "X")
g.make_tensor_output("Y", TensorProto.FLOAT, ())
Expand DownExpand Up@@ -68,7 +69,7 @@ def light_api(

def test_zdoc(self):
onx = (
start(opset=19)
start(opset=19, ir_version=10)
.vin("X")
.reshape((-1, 1))
.Transpose(perm=[1, 0])
Expand All@@ -89,7 +90,7 @@ def light_api(
op.Identity(Y, outputs=["Y"])
return Y

g = GraphBuilder({'': 19})
g = GraphBuilder({'': 19}, ir_version=10)
g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, "X")
g.make_tensor_output("Y", TensorProto.FLOAT, ())
Expand DownExpand Up@@ -117,6 +118,62 @@ def light_api(
self.assertNotEmpty(model)
check_model(model)

def test_exp_f(self):
onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
self.assertIsInstance(onx, ModelProto)
self.assertIn("Exp", str(onx))
ref = ReferenceEvaluator(onx)
a = np.arange(10).astype(np.float32)
got = ref.run(None, {"X": a})[0]
self.assertEqualArray(np.exp(a), got)

tr = Translater(onx, emitter=BuilderEmitter("mm"))
code = tr.export(as_str=True)

expected = dedent(
"""
def light_api(
op: "GraphBuilder",
X: "FLOAT[]",
):
Y = op.Exp(X)
op.Identity(Y, outputs=["Y"])
return Y


def mm() -> "ModelProto":
g = GraphBuilder({'': 19}, ir_version=10)
g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, "X")
g.make_tensor_output("Y", TensorProto.FLOAT, ())
model = g.to_onnx()
return model


model = mm()
"""
).strip("\n")
self.assertEqual(expected, code.strip("\n"))

def light_api(
op: "GraphBuilder",
X: "FLOAT[]", # noqa: F722
):
Y = op.Exp(X)
op.Identity(Y, outputs=["Y"])
return Y

g2 = GraphBuilder({"": 19})
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
light_api(g2.op, "X")
g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
onx2 = g2.to_onnx()

ref = ReferenceEvaluator(onx2)
a = np.arange(10).astype(np.float32)
got = ref.run(None, {"X": a})[0]
self.assertEqualArray(np.exp(a), got)


if __name__ == "__main__":
unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletiononnx_array_api/__init__.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -2,5 +2,5 @@
APIs to create ONNX Graphs.
"""

__version__ = "0.3.0"
__version__ = "0.3.1"
__author__ = "Xavier Dupré"
60 changes: 52 additions & 8 deletionsonnx_array_api/translate_api/builder_emitter.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,10 +4,17 @@
from .base_emitter import BaseEmitter

_types = {
TensorProto.DOUBLE: "DOUBLE",
TensorProto.FLOAT: "FLOAT",
TensorProto.FLOAT16: "FLOAT16",
TensorProto.INT64: "INT64",
TensorProto.INT32: "INT32",
TensorProto.INT16: "INT16",
TensorProto.UINT64: "UINT64",
TensorProto.UINT32: "UINT32",
TensorProto.UINT16: "UINT16",
TensorProto.STRING: "STRING",
TensorProto.BOOL: "BOOL",
}


Expand All@@ -20,6 +27,10 @@ class BuilderEmitter(BaseEmitter):
Converts event into proper code.
"""

def __init__(self, make_model_function: str = ""):
super().__init__()
self.make_model_function = make_model_function

def join(self, rows: List[str], single_line: bool = False) -> str:
"Join the rows"
assert (
Expand All@@ -29,6 +40,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str:

def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
self.opsets = kwargs.get("opsets", {})
self.ir_version = kwargs.get("ir_version", None)
return []

def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
Expand All@@ -43,12 +55,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
)
rows = [
"",
f"g = GraphBuilder({self.opsets})",
(
f"g = GraphBuilder({self.opsets}, ir_version={self.ir_version})"
if self.ir_version
else f"GraphBuilder({self.opsets})"
),
*inputs,
f"{self.name}({inps})",
*outputs,
"model = g.to_onnx()",
]
if self.make_model_function:
rows = [
"",
"",
f'def {self.make_model_function}() -> "ModelProto":',
*[" " + _ for _ in rows[1:]],
" return model",
"",
"",
f"model = {self.make_model_function}()",
]
return rows

def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
Expand DownExpand Up@@ -78,13 +105,16 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
name = kwargs["name"]
itype = kwargs.get("elem_type", 0)
shape = kwargs.get("shape", None)
name = self._clean_result_name(name)
if itype == 0:
inp = "X"
inp =name or"X"
else:
if shape is None:
inp = f'X: "{_itype_to_string(itype)}"'
inp = f'{name}: "{_itype_to_string(itype)}"'
else:
inp = f'X: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"'
inp = (
f'{name}: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"'
)
self.inputs_full.append(inp)
self.inputs.append(name)
self.inputs_full_.append((name, _itype_to_string(itype), shape))
Expand DownExpand Up@@ -113,6 +143,7 @@ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:

def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
name = kwargs["name"]
name = self._clean_result_name(name)
itype = kwargs.get("elem_type", 0)
shape = kwargs.get("shape", None)
self.outputs.append(name)
Expand All@@ -126,6 +157,8 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
if kwargs.get("domain", "") != "":
domain = kwargs["domain"]
op_type = f"{domain}.{op_type}"
else:
domain = ""
atts = kwargs.get("atts", {})
args = []
for k, v in atts.items():
Expand All@@ -134,11 +167,22 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
raise NotImplementedError("Graph attribute not supported yet.")
args.append(f"{k}={vatt}")

outs = ", ".join(outputs)
inps = ", ".join(inputs)
outs = ", ".join(map(self._clean_result_name, outputs))
inps = ", ".join(map(self._clean_result_name, inputs))
op_type = self._emit_node_type(op_type, domain)
sdomain = "" if not domain else f", domain={domain!r}"
if args:
sargs = ", ".join(args)
row = f" {outs} = op.{op_type}({inps}, {sargs})"
if inps:
row = f" {outs} = op.{op_type}({inps}, {sargs}{sdomain})"
else:
row = f" {outs} = op.{op_type}({sargs}{sdomain})"
else:
row = f" {outs} = op.{op_type}({inps})"
row = f" {outs} = op.{op_type}({inps}{sdomain})"
return [row]

def _clean_result_name(self, name):
return name

def _emit_node_type(self, op_type, domain):
return op_type
6 changes: 5 additions & 1 deletiononnx_array_api/translate_api/translate.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -35,7 +35,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
last_event = None
if isinstance(self.proto_, ModelProto):
opsets = {d.domain: d.version for d in self.proto_.opset_import}
rows.extend(self.emitter(EventType.START, opsets=opsets))
rows.extend(
self.emitter(
EventType.START, opsets=opsets, ir_version=self.proto_.ir_version
)
)
inputs = self.proto_.graph.input
outputs = self.proto_.graph.output
nodes = self.proto_.graph.node
Expand Down
Loading

[8]ページ先頭

©2009-2025 Movatter.jp