Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit8aa1f28

Browse files
committed
2 parents32fc52e +01e0fac commit8aa1f28

File tree

9 files changed

+482
-8
lines changed

9 files changed

+482
-8
lines changed

‎CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.3.0
55
+++++
66

7+
*:pr:`87`: add command line to replace contant by ConstantOfShape
78
*:pr:`79`: first draft to export to GraphBuilder
89
*:pr:`77`: supports ConcatOfShape and Slice with the light API
910

‎_doc/api/tools.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ Benchmark
66

77
..autofunction::onnx_array_api.ext_test_case.measure_time
88

9+
Manipulations
10+
+++++++++++++
11+
12+
..autofunction::onnx_array_api.tools.replace_constants.replace_initializer_by_constant_of_shape
13+
914
Examples
1015
++++++++
1116

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
importunittest
2+
importnumpyasnp
3+
importonnx
4+
importonnx.helperasoh
5+
importonnx.numpy_helperasonh
6+
fromonnximportTensorProto
7+
fromonnx_array_api.ext_test_caseimportExtTestCase
8+
fromonnx_array_api.referenceimport (
9+
ExtendedReferenceEvaluatorasReferenceEvaluator,
10+
)
11+
fromonnx_array_api.tools.replace_constantsimport (
12+
replace_initializer_by_constant_of_shape,
13+
)
14+
15+
16+
classTestReplaceConstants(ExtTestCase):
17+
18+
deftest_replace_initializer(self):
19+
dtype=np.float32
20+
value=np.random.randn(2,100).astype(dtype)
21+
A=onh.from_array(value,name="A")
22+
value=np.array([1],dtype=dtype)
23+
C=onh.from_array(value,name="C")
24+
25+
X=oh.make_tensor_value_info("X",TensorProto.FLOAT, [None,None])
26+
Y=oh.make_tensor_value_info("Y",TensorProto.FLOAT, [None])
27+
node1=oh.make_node("MatMul", ["X","A"], ["AX"])
28+
node2=oh.make_node("Sub", ["AX","C"], ["Y"])
29+
graph=oh.make_graph([node1,node2],"lr", [X], [Y], [A,C])
30+
model_def=oh.make_model(graph)
31+
32+
x=np.array([1,2,4,5,5,4]).astype(np.float32).reshape((3,2))
33+
oinf1=ReferenceEvaluator(model_def)
34+
y1=oinf1.run(None, {"X":x})[0]# type: ignore[index]
35+
repl=replace_initializer_by_constant_of_shape(model_def)
36+
node_types= {n.op_typeforninrepl.graph.node}
37+
self.assertIn("ConstantOfShape",node_types)
38+
oinf2=ReferenceEvaluator(repl)
39+
y1[:, :]=3.5
40+
y1[0, :]=0.5
41+
y2=oinf2.run(None, {"X":x})[0]# type: ignore[index]
42+
self.assertEqualArray(y1,y2)
43+
44+
deftest_replace_constant(self):
45+
dtype=np.float32
46+
value=np.random.randn(2,10).astype(dtype)
47+
A=onh.from_array(value,name="A")
48+
value=np.array([1],dtype=dtype)
49+
C=onh.from_array(value,name="C")
50+
51+
X=oh.make_tensor_value_info("X",TensorProto.FLOAT, [None,None])
52+
Y=oh.make_tensor_value_info("Y",TensorProto.FLOAT, [None])
53+
node0=oh.make_node("Constant", [], ["A"],value=A)
54+
node1=oh.make_node("MatMul", ["X","A"], ["AX"])
55+
node2=oh.make_node("Sub", ["AX","C"], ["Y"])
56+
graph=oh.make_graph([node0,node1,node2],"lr", [X], [Y], [C])
57+
model_def=oh.make_model(graph)
58+
59+
x=np.array([1,2,4,5,5,4]).astype(np.float32).reshape((3,2))
60+
oinf1=ReferenceEvaluator(model_def)
61+
y1=oinf1.run(None, {"X":x})[0]# type: ignore[index]
62+
repl=replace_initializer_by_constant_of_shape(model_def,threshold=0)
63+
node_types= {n.op_typeforninrepl.graph.node}
64+
self.assertIn("ConstantOfShape",node_types)
65+
oinf2=ReferenceEvaluator(repl)
66+
y1[:, :]=4
67+
y1[0, :]=1
68+
y2=oinf2.run(None, {"X":x})[0]# type: ignore[index]
69+
self.assertEqualArray(y1,y2)
70+
71+
deftest_replace_constant_function(self):
72+
dtype=np.float32
73+
value=np.random.randn(2,100).astype(dtype)
74+
A=onh.from_array(value,name="A")
75+
value=np.array([1],dtype=dtype)
76+
C=onh.from_array(value,name="C")
77+
78+
X=oh.make_tensor_value_info("X",TensorProto.FLOAT, [None,None])
79+
Y=oh.make_tensor_value_info("Y",TensorProto.FLOAT, [None])
80+
nodeC=oh.make_node("Constant", [], ["C"],value=C)
81+
node0=oh.make_node("Constant", [], ["A"],value=A)
82+
node1=oh.make_node("MatMul", ["X","A"], ["AX"])
83+
node2=oh.make_node("Sub", ["AX","C"], ["Y"])
84+
opset_imports= [
85+
oh.make_opsetid("",onnx.defs.onnx_opset_version()),
86+
oh.make_opsetid("custom",1),
87+
]
88+
fct=oh.make_function(
89+
"custom",
90+
"unittest",
91+
["X"],
92+
["Y"],
93+
[nodeC,node0,node1,node2],
94+
opset_imports,
95+
)
96+
97+
node=oh.make_node("unittest", ["X"], ["Y"],domain="custom")
98+
graph=oh.make_graph([node],"lr", [X], [Y], [C])
99+
model_def=oh.make_model(graph,functions=[fct],opset_imports=opset_imports)
100+
101+
x=np.array([1,2,4,5,5,4]).astype(np.float32).reshape((3,2))
102+
oinf1=ReferenceEvaluator(model_def)
103+
y1=oinf1.run(None, {"X":x})[0]# type: ignore[index]
104+
repl=replace_initializer_by_constant_of_shape(model_def)
105+
node_types= {n.op_typeforninrepl.functions[0].node}
106+
self.assertIn("ConstantOfShape",node_types)
107+
oinf2=ReferenceEvaluator(repl)
108+
y1[:, :]=3.5
109+
y1[0, :]=0.5
110+
y2=oinf2.run(None, {"X":x})[0]# type: ignore[index]
111+
self.assertEqualArray(y1,y2)
112+
113+
deftest_replace_constant_graph(self):
114+
value=np.array([0],dtype=np.float32)
115+
zero=onh.from_array(value,name="zero")
116+
117+
X=oh.make_tensor_value_info("X",onnx.TensorProto.FLOAT, [None,None])
118+
Y=oh.make_tensor_value_info("Y",onnx.TensorProto.FLOAT, [None])
119+
120+
rsum=oh.make_node("ReduceSum", ["X"], ["rsum"])
121+
cond=oh.make_node("Greater", ["rsum","zero"], ["cond"])
122+
123+
then_out=oh.make_tensor_value_info("then_out",onnx.TensorProto.FLOAT,None)
124+
then_cst=onh.from_array(np.array([1]*129).astype(np.float32))
125+
126+
then_const_node=oh.make_node(
127+
"Constant",inputs=[],outputs=["then_out"],value=then_cst,name="cst1"
128+
)
129+
then_body=oh.make_graph([then_const_node],"then_body", [], [then_out])
130+
131+
else_out=oh.make_tensor_value_info("else_out",onnx.TensorProto.FLOAT,None)
132+
else_cst=onh.from_array(np.array([-1]*129).astype(np.float32))
133+
else_const_node=oh.make_node(
134+
"Constant",inputs=[],outputs=["else_out"],value=else_cst,name="cst2"
135+
)
136+
else_body=oh.make_graph([else_const_node],"else_body", [], [else_out])
137+
138+
if_node=oh.make_node(
139+
"If", ["cond"], ["Y"],then_branch=then_body,else_branch=else_body
140+
)
141+
graph=oh.make_graph([rsum,cond,if_node],"if", [X], [Y], [zero])
142+
onnx_model=oh.make_model(
143+
graph,opset_imports=[oh.make_opsetid("",onnx.defs.onnx_opset_version())]
144+
)
145+
self.assertNotIn("ConstantOfShape",str(onnx_model))
146+
147+
x=np.ones((3,2),dtype=np.float32)
148+
oinf1=ReferenceEvaluator(onnx_model)
149+
y1=oinf1.run(None, {"X":x})[0]# type: ignore[index]
150+
repl=replace_initializer_by_constant_of_shape(onnx_model)
151+
self.assertIn("ConstantOfShape",str(repl))
152+
oinf2=ReferenceEvaluator(repl)
153+
y2=oinf2.run(None, {"X":x})[0]# type: ignore[index]
154+
y1=y1.copy()
155+
y1[:]=0.5
156+
self.assertEqualArray(y1,y2)
157+
158+
159+
if__name__=="__main__":
160+
unittest.main(verbosity=2)

‎_unittests/ut_xrun_doc/test_command_lines1.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
get_main_parser,
1717
get_parser_compare,
1818
get_parser_translate,
19+
get_parser_replace,
1920
main,
2021
)
2122

@@ -35,6 +36,13 @@ def test_parser_translate(self):
3536
text=st.getvalue()
3637
self.assertIn("model",text)
3738

39+
deftest_parser_replace(self):
40+
st=StringIO()
41+
withredirect_stdout(st):
42+
get_parser_replace().print_help()
43+
text=st.getvalue()
44+
self.assertIn("model",text)
45+
3846
deftest_command_translate(self):
3947
X=make_tensor_value_info("X",TensorProto.FLOAT, [None,None])
4048
Y=make_tensor_value_info("Y",TensorProto.FLOAT, [5,6])

‎onnx_array_api/_command_lines_parser.py

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ def get_main_parser() -> ArgumentParser:
1414
)
1515
parser.add_argument(
1616
"cmd",
17-
choices=["translate","compare"],
17+
choices=["translate","compare","replace"],
1818
help=dedent(
1919
"""
2020
Selects a command.
2121
2222
'translate' exports an onnx graph into a piece of code replicating it,
23-
'compare' compares the execution of two onnx models
23+
'compare' compares the execution of two onnx models,
24+
'replace' replaces constant and initliazers by ConstantOfShape to make the model lighter
2425
"""
2526
),
2627
)
@@ -142,8 +143,75 @@ def _cmd_compare(argv: List[Any]):
142143
print(text)
143144

144145

146+
defget_parser_replace()->ArgumentParser:
147+
parser=ArgumentParser(
148+
prog="translate",
149+
description=dedent(
150+
"""
151+
Replaces constants and initializes by ConstOfShape or any other nodes
152+
to make the model smaller.
153+
"""
154+
),
155+
epilog="This is mostly used to write unit tests without adding "
156+
"a big file to the repository.",
157+
)
158+
parser.add_argument(
159+
"-m",
160+
"--model",
161+
type=str,
162+
required=True,
163+
help="onnx model to translate",
164+
)
165+
parser.add_argument(
166+
"-o",
167+
"--out",
168+
type=str,
169+
required=True,
170+
help="output file",
171+
)
172+
parser.add_argument(
173+
"-t",
174+
"--threshold",
175+
default=128,
176+
help="Threshold above which every constant is replaced",
177+
)
178+
parser.add_argument(
179+
"--type",
180+
default="ConstontOfShape",
181+
help="Inserts this operator type",
182+
)
183+
parser.add_argument(
184+
"--domain",
185+
default="",
186+
help="Inserts this domain",
187+
)
188+
parser.add_argument(
189+
"-v",
190+
"--verbose",
191+
default=0,
192+
help="verbosity",
193+
)
194+
returnparser
195+
196+
197+
def_cmd_replace(argv:List[Any]):
198+
from .tools.replace_constantsimportreplace_initializer_by_constant_of_shape
199+
200+
parser=get_parser_replace()
201+
args=parser.parse_args(argv[1:])
202+
ifargs.verbosein ("1",1,"True",True):
203+
print(f"[compare] load model{args.model!r}")
204+
onx=onnx.load(args.model)
205+
new_onx=replace_initializer_by_constant_of_shape(
206+
onx,threshold=args.threshold,op_type=args.type,domain=args.domain
207+
)
208+
ifargs.verbosein ("1",1,"True",True):
209+
print(f"[compare] save model{args.out!r}")
210+
onnx.save(new_onx,args.out)
211+
212+
145213
defmain(argv:Optional[List[Any]]=None):
146-
fcts=dict(translate=_cmd_translate,compare=_cmd_compare)
214+
fcts=dict(translate=_cmd_translate,compare=_cmd_compare,replace=_cmd_replace)
147215

148216
ifargvisNone:
149217
argv=sys.argv[1:]
@@ -152,7 +220,11 @@ def main(argv: Optional[List[Any]] = None):
152220
parser=get_main_parser()
153221
parser.parse_args(argv)
154222
else:
155-
parsers=dict(translate=get_parser_translate,compare=get_parser_compare)
223+
parsers=dict(
224+
translate=get_parser_translate,
225+
compare=get_parser_compare,
226+
replace=get_parser_replace,
227+
)
156228
cmd=argv[0]
157229
ifcmdnotinparsers:
158230
raiseValueError(

‎onnx_array_api/array_api/_onnx_common.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,13 @@ def asarray(
4646
dtype:Optional[DType]=None,
4747
order:Optional[str]=None,
4848
like:Any=None,
49+
device:Optional[str]=None,
4950
copy:bool=False,
5051
)->EagerTensor:
5152
"""
5253
Converts anything into an array.
5354
"""
54-
"""
55-
Converts anything into an array.
56-
"""
55+
assertdeviceisNone,f"asarray not implemented yet for device={device!r}"
5756
ifordernotin ("C",None):
5857
raiseNotImplementedError(f"asarray is not implemented for order={order!r}.")
5958
iflikeisnotNone:

‎onnx_array_api/npx/npx_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,8 @@ def astype(
281281
to=DType(TensorProto.STRING)
282282
else:
283283
raiseTypeError(f"dtype must of type DType, not{type(dtype)}-{dtype}.")
284-
returnvar(a,op="Cast",to=to.code)
284+
returnvar(a,op="Cast",to=to.code)
285+
returnvar(a,op="Cast",to=dtype.code)
285286

286287

287288
@npxapi_inline

‎onnx_array_api/tools/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp