Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Supports for local functions in translator#96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
sdpython merged 6 commits intomainfromlf
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletionsCHANGELOGS.rst
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,6 +4,7 @@ Change Logs
0.3.1
+++++

* :pr:`96`: supports local functions in translator
* :pr:`95`: improves translation to GraphBuilder

0.3.0
Expand Down
144 changes: 125 additions & 19 deletions_unittests/ut_translate_api/test_translate_builder.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
import unittest
from textwrap import dedent
import numpy as np
import onnx.helper as oh
from onnx import ModelProto, TensorProto
from onnx.checker import check_model
from onnx.defs import onnx_opset_version
Expand DownExpand Up@@ -29,37 +30,43 @@ def test_exp(self):
self.assertEqualArray(np.exp(a), got)

code = translate(onx, api="builder")
expected = dedent(
"""
expected = (
dedent(
"""
def light_api(
op: "GraphBuilder",
X: "FLOAT[]",
):
Y = op.Exp(X)
Y = op.Exp(X, outputs=['Y'])
op.Identity(Y, outputs=["Y"])
return Y

g = GraphBuilder({'': 19}, ir_version=10)
g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, "X")
g.make_tensor_output("Y", TensorProto.FLOAT, ())
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
model = g.to_onnx()
"""
).strip("\n")
)
.strip("\n")
.replace("__SUFFIX__", ", is_dimension=False, indexed=False")
)
self.assertEqual(expected, code.strip("\n"))

def light_api(
op: "GraphBuilder",
X: "FLOAT[]", # noqa: F722
):
Y = op.Exp(X)
Y = op.Exp(X, outputs=["Y"])
op.Identity(Y, outputs=["Y"])
return Y

g2 = GraphBuilder({"": 19})
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
light_api(g2.op, "X")
g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
g2.make_tensor_output(
"Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
)
onx2 = g2.to_onnx()

ref = ReferenceEvaluator(onx2)
Expand All@@ -78,25 +85,29 @@ def test_zdoc(self):
.to_onnx()
)
code = translate(onx, api="builder")
expected = dedent(
"""
expected = (
dedent(
"""
def light_api(
op: "GraphBuilder",
X: "FLOAT[]",
):
r = np.array([-1, 1], dtype=np.int64)
r0_0 = op.Reshape(X, r)
Y = op.Transpose(r0_0, perm=[1, 0])
r0_0 = op.Reshape(X, r, outputs=['r0_0'])
Y = op.Transpose(r0_0, perm=[1, 0], outputs=['Y'])
op.Identity(Y, outputs=["Y"])
return Y

g = GraphBuilder({'': 19}, ir_version=10)
g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, "X")
g.make_tensor_output("Y", TensorProto.FLOAT, ())
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
model = g.to_onnx()
"""
).strip("\n")
)
.strip("\n")
.replace("__SUFFIX__", ", is_dimension=False, indexed=False")
)
self.maxDiff = None
self.assertEqual(expected, code.strip("\n"))

Expand DownExpand Up@@ -130,13 +141,14 @@ def test_exp_f(self):
tr = Translater(onx, emitter=BuilderEmitter("mm"))
code = tr.export(as_str=True)

expected = dedent(
"""
expected = (
dedent(
"""
def light_api(
op: "GraphBuilder",
X: "FLOAT[]",
):
Y = op.Exp(X)
Y = op.Exp(X, outputs=['Y'])
op.Identity(Y, outputs=["Y"])
return Y

Expand All@@ -145,14 +157,17 @@ def mm() -> "ModelProto":
g = GraphBuilder({'': 19}, ir_version=10)
g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, "X")
g.make_tensor_output("Y", TensorProto.FLOAT, ())
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
model = g.to_onnx()
return model


model = mm()
"""
).strip("\n")
)
.strip("\n")
.replace("__SUFFIX__", ", is_dimension=False, indexed=False")
)
self.assertEqual(expected, code.strip("\n"))

def light_api(
Expand All@@ -166,14 +181,105 @@ def light_api(
g2 = GraphBuilder({"": 19})
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
light_api(g2.op, "X")
g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
g2.make_tensor_output(
"Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
)
onx2 = g2.to_onnx()

ref = ReferenceEvaluator(onx2)
a = np.arange(10).astype(np.float32)
got = ref.run(None, {"X": a})[0]
self.assertEqualArray(np.exp(a), got)

def test_local_function(self):
new_domain = "custom"

linear_regression = oh.make_function(
new_domain,
"LinearRegression",
["x", "a", "b"],
["y"],
[
oh.make_node("MatMul", ["x", "a"], ["xa"]),
oh.make_node("Add", ["xa", "b"], ["y"]),
],
[oh.make_opsetid("", 14)],
[],
)

graph = oh.make_graph(
[
oh.make_node(
"LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain
),
oh.make_node("Abs", ["Y1"], ["Y"]),
],
"example",
[
oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None]),
oh.make_tensor_value_info("A", TensorProto.FLOAT, [None, None]),
oh.make_tensor_value_info("B", TensorProto.FLOAT, [None, None]),
],
[oh.make_tensor_value_info("Y", TensorProto.FLOAT, None)],
)

onnx_model = oh.make_model(
graph,
opset_imports=[oh.make_opsetid("", 14), oh.make_opsetid(new_domain, 1)],
functions=[linear_regression],
ir_version=10,
)
tr = Translater(onnx_model, emitter=BuilderEmitter("mm"))
code = tr.export(as_str=True)

expected = (
dedent(
"""
def example(
op: "GraphBuilder",
X: "FLOAT[, ]",
A: "FLOAT[, ]",
B: "FLOAT[, ]",
):
Y1 = op.LinearRegression(X, A, B, domain='custom', outputs=['Y1'])
Y = op.Abs(Y1, outputs=['Y'])
op.Identity(Y, outputs=["Y"])
return Y


def make_custom_LinearRegression(g: "GraphBuilder"):
gr = GraphBuilder({'': 14}, as_function=True)
x = gr.make_tensor_input('x')
a = gr.make_tensor_input('a')
b = gr.make_tensor_input('b')
op = gr.op
xa = op.MatMul(x, a, outputs=['xa'])
y = op.Add(xa, b, outputs=['y'])
gr.make_tensor_output(y)
g.add_function(builder=gr)
return gr


def mm() -> "ModelProto":
g = GraphBuilder({'': 14, 'custom': 1}, ir_version=10)
g.make_tensor_input("X", TensorProto.FLOAT, ('', ''))
g.make_tensor_input("A", TensorProto.FLOAT, ('', ''))
g.make_tensor_input("B", TensorProto.FLOAT, ('', ''))
example(g.op, "X", "A", "B")
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
make_custom_LinearRegression(g)
model = g.to_onnx()
return model


model = mm()
"""
)
.strip("\n")
.replace("__SUFFIX__", ", is_dimension=False, indexed=False")
)
self.assertEqual(expected, code.strip("\n"))


if __name__ == "__main__":
unittest.main(verbosity=2)
13 changes: 13 additions & 0 deletionsonnx_array_api/graph_api/graph_builder.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -194,6 +194,7 @@ def __init__(
self._known_shapes = {}
self._known_types = {}
self.constants_ = {}
self.functions_ = {}
elif isinstance(target_opset_or_existing_proto, ModelProto):
assert (
not input_names
Expand DownExpand Up@@ -223,6 +224,8 @@ def __init__(
self.constants_[node.output[0]] = node
self.set_shape(node.output[0], self._get_tensor_shape(node))
self.set_type(node.output[0], self._get_tensor_type(node))
for f in proto.functions:
self.add_function(f)
else:
raise NotImplementedError(
f"{type(target_opset_or_existing_proto)} is not supported."
Expand All@@ -231,6 +234,14 @@ def __init__(
self.op = Opset(self, self.opsets[""]) if "" in self.opsets else None
self._cache_array = []

def add_local_function(self, domain: str, name: str, gr: "GraphBuilder"):
"Adds a local function."
assert (
domain,
name,
) not in self.functions_, f"Function {(domain, name)} was already added."
self.functions_[domain, name] = gr

def _get_tensor_shape(
self, proto: Union[NodeProto, TensorProto]
) -> Tuple[int, ...]:
Expand DownExpand Up@@ -417,6 +428,8 @@ def make_tensor_output(
name: Union[str, List[str]],
elem_type: Optional[int] = None,
shape: Optional[Tuple[int, ...]] = None,
is_dimension: bool = False,
indexed: bool = False,
) -> Union[str, List[str]]:
if isinstance(name, list):
res = []
Expand Down
28 changes: 28 additions & 0 deletionsonnx_array_api/translate_api/base_emitter.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -25,6 +25,10 @@ class EventType(IntEnum):
END_SIGNATURE = 16
BEGIN_RETURN = 17
END_RETURN = 18
BEGIN_FUNCTION_SIGNATURE = 19
END_FUNCTION_SIGNATURE = 20
BEGIN_FUNCTION_RETURN = 21
END_FUNCTION_RETURN = 22

@classmethod
def to_str(cls, self) -> str:
Expand DownExpand Up@@ -76,6 +80,12 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
if event == EventType.BEGIN_FUNCTION:
return self._emit_begin_function(**kwargs)

if event == EventType.BEGIN_FUNCTION_SIGNATURE:
return self._emit_begin_function_signature(**kwargs)

if event == EventType.END_FUNCTION_SIGNATURE:
return self._emit_end_function_signature(**kwargs)

if event == EventType.END_FUNCTION:
return self._emit_end_function(**kwargs)

Expand All@@ -100,6 +110,12 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
if event == EventType.END_RETURN:
return self._emit_end_return(**kwargs)

if event == EventType.BEGIN_FUNCTION_RETURN:
return self._emit_begin_function_return(**kwargs)

if event == EventType.END_FUNCTION_RETURN:
return self._emit_end_function_return(**kwargs)

raise ValueError(f"Unexpected event {EventType.to_str(event)}.")

def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
Expand DownExpand Up@@ -224,6 +240,12 @@ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
)

def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
return []

def _emit_end_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
return []

def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]:
raise NotImplementedError(
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
Expand All@@ -250,3 +272,9 @@ def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:

def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
return []

def _emit_begin_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
return []

def _emit_end_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
return []
Loading
Loading

[8]ページ先頭

©2009-2025 Movatter.jp