Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Add command line to replace constants in a model#87

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
sdpython merged 7 commits intomainfromrepl
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletionsCHANGELOGS.rst
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,6 +4,7 @@ Change Logs
0.3.0
+++++

* :pr:`87`: add command line to replace contant by ConstantOfShape
* :pr:`79`: first draft to export to GraphBuilder
* :pr:`77`: supports ConcatOfShape and Slice with the light API

Expand Down
5 changes: 5 additions & 0 deletions_doc/api/tools.rst
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -6,6 +6,11 @@ Benchmark

.. autofunction:: onnx_array_api.ext_test_case.measure_time

Manipulations
+++++++++++++

.. autofunction:: onnx_array_api.tools.replace_constants.replace_initializer_by_constant_of_shape

Examples
++++++++

Expand Down
160 changes: 160 additions & 0 deletions_unittests/ut_tools/test_replace_constants.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
import unittest
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnx import TensorProto
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.reference import (
ExtendedReferenceEvaluator as ReferenceEvaluator,
)
from onnx_array_api.tools.replace_constants import (
replace_initializer_by_constant_of_shape,
)


class TestReplaceConstants(ExtTestCase):

def test_replace_initializer(self):
dtype = np.float32
value = np.random.randn(2, 100).astype(dtype)
A = onh.from_array(value, name="A")
value = np.array([1], dtype=dtype)
C = onh.from_array(value, name="C")

X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
graph = oh.make_graph([node1, node2], "lr", [X], [Y], [A, C])
model_def = oh.make_model(graph)

x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
oinf1 = ReferenceEvaluator(model_def)
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(model_def)
node_types = {n.op_type for n in repl.graph.node}
self.assertIn("ConstantOfShape", node_types)
oinf2 = ReferenceEvaluator(repl)
y1[:, :] = 3.5
y1[0, :] = 0.5
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
self.assertEqualArray(y1, y2)

def test_replace_constant(self):
dtype = np.float32
value = np.random.randn(2, 10).astype(dtype)
A = onh.from_array(value, name="A")
value = np.array([1], dtype=dtype)
C = onh.from_array(value, name="C")

X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
node0 = oh.make_node("Constant", [], ["A"], value=A)
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
graph = oh.make_graph([node0, node1, node2], "lr", [X], [Y], [C])
model_def = oh.make_model(graph)

x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
oinf1 = ReferenceEvaluator(model_def)
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(model_def, threshold=0)
node_types = {n.op_type for n in repl.graph.node}
self.assertIn("ConstantOfShape", node_types)
oinf2 = ReferenceEvaluator(repl)
y1[:, :] = 4
y1[0, :] = 1
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
self.assertEqualArray(y1, y2)

def test_replace_constant_function(self):
dtype = np.float32
value = np.random.randn(2, 100).astype(dtype)
A = onh.from_array(value, name="A")
value = np.array([1], dtype=dtype)
C = onh.from_array(value, name="C")

X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
nodeC = oh.make_node("Constant", [], ["C"], value=C)
node0 = oh.make_node("Constant", [], ["A"], value=A)
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
opset_imports = [
oh.make_opsetid("", onnx.defs.onnx_opset_version()),
oh.make_opsetid("custom", 1),
]
fct = oh.make_function(
"custom",
"unittest",
["X"],
["Y"],
[nodeC, node0, node1, node2],
opset_imports,
)

node = oh.make_node("unittest", ["X"], ["Y"], domain="custom")
graph = oh.make_graph([node], "lr", [X], [Y], [C])
model_def = oh.make_model(graph, functions=[fct], opset_imports=opset_imports)

x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
oinf1 = ReferenceEvaluator(model_def)
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(model_def)
node_types = {n.op_type for n in repl.functions[0].node}
self.assertIn("ConstantOfShape", node_types)
oinf2 = ReferenceEvaluator(repl)
y1[:, :] = 3.5
y1[0, :] = 0.5
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
self.assertEqualArray(y1, y2)

def test_replace_constant_graph(self):
value = np.array([0], dtype=np.float32)
zero = onh.from_array(value, name="zero")

X = oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None])
Y = oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [None])

rsum = oh.make_node("ReduceSum", ["X"], ["rsum"])
cond = oh.make_node("Greater", ["rsum", "zero"], ["cond"])

then_out = oh.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, None)
then_cst = onh.from_array(np.array([1] * 129).astype(np.float32))

then_const_node = oh.make_node(
"Constant", inputs=[], outputs=["then_out"], value=then_cst, name="cst1"
)
then_body = oh.make_graph([then_const_node], "then_body", [], [then_out])

else_out = oh.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, None)
else_cst = onh.from_array(np.array([-1] * 129).astype(np.float32))
else_const_node = oh.make_node(
"Constant", inputs=[], outputs=["else_out"], value=else_cst, name="cst2"
)
else_body = oh.make_graph([else_const_node], "else_body", [], [else_out])

if_node = oh.make_node(
"If", ["cond"], ["Y"], then_branch=then_body, else_branch=else_body
)
graph = oh.make_graph([rsum, cond, if_node], "if", [X], [Y], [zero])
onnx_model = oh.make_model(
graph, opset_imports=[oh.make_opsetid("", onnx.defs.onnx_opset_version())]
)
self.assertNotIn("ConstantOfShape", str(onnx_model))

x = np.ones((3, 2), dtype=np.float32)
oinf1 = ReferenceEvaluator(onnx_model)
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(onnx_model)
self.assertIn("ConstantOfShape", str(repl))
oinf2 = ReferenceEvaluator(repl)
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
y1 = y1.copy()
y1[:] = 0.5
self.assertEqualArray(y1, y2)


if __name__ == "__main__":
unittest.main(verbosity=2)
8 changes: 8 additions & 0 deletions_unittests/ut_xrun_doc/test_command_lines1.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -16,6 +16,7 @@
get_main_parser,
get_parser_compare,
get_parser_translate,
get_parser_replace,
main,
)

Expand All@@ -35,6 +36,13 @@ def test_parser_translate(self):
text = st.getvalue()
self.assertIn("model", text)

def test_parser_replace(self):
st = StringIO()
with redirect_stdout(st):
get_parser_replace().print_help()
text = st.getvalue()
self.assertIn("model", text)

def test_command_translate(self):
X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6])
Expand Down
80 changes: 76 additions & 4 deletionsonnx_array_api/_command_lines_parser.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -14,13 +14,14 @@ def get_main_parser() -> ArgumentParser:
)
parser.add_argument(
"cmd",
choices=["translate", "compare"],
choices=["translate", "compare", "replace"],
help=dedent(
"""
Selects a command.

'translate' exports an onnx graph into a piece of code replicating it,
'compare' compares the execution of two onnx models
'compare' compares the execution of two onnx models,
'replace' replaces constant and initliazers by ConstantOfShape to make the model lighter
"""
),
)
Expand DownExpand Up@@ -142,8 +143,75 @@ def _cmd_compare(argv: List[Any]):
print(text)


def get_parser_replace() -> ArgumentParser:
parser = ArgumentParser(
prog="translate",
description=dedent(
"""
Replaces constants and initializes by ConstOfShape or any other nodes
to make the model smaller.
"""
),
epilog="This is mostly used to write unit tests without adding "
"a big file to the repository.",
)
parser.add_argument(
"-m",
"--model",
type=str,
required=True,
help="onnx model to translate",
)
parser.add_argument(
"-o",
"--out",
type=str,
required=True,
help="output file",
)
parser.add_argument(
"-t",
"--threshold",
default=128,
help="Threshold above which every constant is replaced",
)
parser.add_argument(
"--type",
default="ConstontOfShape",
help="Inserts this operator type",
)
parser.add_argument(
"--domain",
default="",
help="Inserts this domain",
)
parser.add_argument(
"-v",
"--verbose",
default=0,
help="verbosity",
)
return parser


def _cmd_replace(argv: List[Any]):
from .tools.replace_constants import replace_initializer_by_constant_of_shape

parser = get_parser_replace()
args = parser.parse_args(argv[1:])
if args.verbose in ("1", 1, "True", True):
print(f"[compare] load model {args.model!r}")
onx = onnx.load(args.model)
new_onx = replace_initializer_by_constant_of_shape(
onx, threshold=args.threshold, op_type=args.type, domain=args.domain
)
if args.verbose in ("1", 1, "True", True):
print(f"[compare] save model {args.out!r}")
onnx.save(new_onx, args.out)


def main(argv: Optional[List[Any]] = None):
fcts = dict(translate=_cmd_translate, compare=_cmd_compare)
fcts = dict(translate=_cmd_translate, compare=_cmd_compare, replace=_cmd_replace)

if argv is None:
argv = sys.argv[1:]
Expand All@@ -152,7 +220,11 @@ def main(argv: Optional[List[Any]] = None):
parser = get_main_parser()
parser.parse_args(argv)
else:
parsers = dict(translate=get_parser_translate, compare=get_parser_compare)
parsers = dict(
translate=get_parser_translate,
compare=get_parser_compare,
replace=get_parser_replace,
)
cmd = argv[0]
if cmd not in parsers:
raise ValueError(
Expand Down
5 changes: 2 additions & 3 deletionsonnx_array_api/array_api/_onnx_common.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -46,14 +46,13 @@ def asarray(
dtype: Optional[DType] = None,
order: Optional[str] = None,
like: Any = None,
device: Optional[str] = None,
copy: bool = False,
) -> EagerTensor:
"""
Converts anything into an array.
"""
"""
Converts anything into an array.
"""
assert device is None, f"asarray not implemented yet for device={device!r}"
if order not in ("C", None):
raise NotImplementedError(f"asarray is not implemented for order={order!r}.")
if like is not None:
Expand Down
3 changes: 2 additions & 1 deletiononnx_array_api/npx/npx_functions.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -281,7 +281,8 @@ def astype(
to = DType(TensorProto.STRING)
else:
raise TypeError(f"dtype must of type DType, not {type(dtype)}-{dtype}.")
return var(a, op="Cast", to=to.code)
return var(a, op="Cast", to=to.code)
return var(a, op="Cast", to=dtype.code)


@npxapi_inline
Expand Down
1 change: 1 addition & 0 deletionsonnx_array_api/tools/__init__.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@

Loading
Loading

[8]ページ先頭

©2009-2025 Movatter.jp