Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit2c739f9

Browse files
committed
improve code coverage
1 parentf513e4b commit2c739f9

File tree

2 files changed

+175
-28
lines changed

2 files changed

+175
-28
lines changed

‎_unittests/ut_graph_api/test_graph_builder.py‎

Lines changed: 140 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
importunittest
44
importnumpyasnp
55
importonnx
6-
fromonnx.referenceimportReferenceEvaluator
76
fromonnx_array_api.ext_test_caseimportExtTestCase
8-
fromonnx_array_api.graph_api.graph_builderimportGraphBuilder
7+
fromonnx_array_api.graph_api.graph_builderimportGraphBuilder,OptimizationOptions
8+
fromonnx_array_api.referenceimport (
9+
from_array_extended,
10+
ExtendedReferenceEvaluatorasReferenceEvaluator,
11+
)
912

1013

1114
classTestGraphBuilder(ExtTestCase):
@@ -130,6 +133,35 @@ def test_constant_folding(self):
130133
got=ref.run(None,feeds)
131134
self.assertEqualArray(expected,got[0])
132135

136+
deftest_constant_folding2(self):
137+
g=GraphBuilder(
138+
optimization_options=OptimizationOptions(constant_folding=True)
139+
)
140+
141+
shape= (10,4)
142+
w=np.random.randn(*shape).astype(np.float32)
143+
x=g.make_tensor_input("X",np.float32,shape)
144+
weight=g.make_initializer(w)
145+
cst=g.get_constant(weight)
146+
self.assertEqualArray(w,cst)
147+
one=g.make_initializer(np.array([-1,1],dtype=np.int64))
148+
transposed=g.make_node("Transpose", [weight],perm=[1,0])
149+
res=g.op.MatMul(x,transposed)
150+
g.op.Reshape(res,one,outputs="y")
151+
g.make_tensor_output("y",np.float32, (10,1))
152+
153+
g.optimize()
154+
155+
onx=g.to_onnx()
156+
node_types= [n.op_typeforninonx.graph.node]
157+
self.assertNotIn("Transpose",node_types)
158+
ref=ReferenceEvaluator(onx)
159+
x=np.random.randn(*shape).astype(np.float32)
160+
expected= (x @w.T).reshape((-1,1))
161+
feeds= {"X":x}
162+
got=ref.run(None,feeds)
163+
self.assertEqualArray(expected,got[0])
164+
133165
deftest_remove_identity(self):
134166
withcontextlib.redirect_stdout(io.StringIO()):
135167
g=GraphBuilder(verbose=10)
@@ -238,6 +270,112 @@ def test_remove_unused_nodes_simple(self):
238270
got=ref.run(None,feeds)
239271
self.assertEqualArray(expected,got[0])
240272

273+
deftest_constant_array(self):
274+
withcontextlib.redirect_stdout(io.StringIO()):
275+
g=GraphBuilder(verbose=10)
276+
277+
shape= (10,4)
278+
w=np.random.randn(*shape).astype(np.float32)
279+
280+
x=g.make_tensor_input("X",np.float32,shape)
281+
one=g.make_initializer(np.array([-1,1],dtype=np.int64))
282+
res=g.op.MatMul(x,w.T)
283+
g.op.Reshape(res,one,outputs="y")
284+
g.make_tensor_output("y",np.float32, (10,1))
285+
onx=g.to_onnx()
286+
ref=ReferenceEvaluator(onx)
287+
x=np.random.randn(*shape).astype(np.float32)
288+
expected= (x @w.T).reshape((-1,1))
289+
feeds= {"X":x}
290+
got=ref.run(None,feeds)
291+
self.assertEqualArray(expected,got[0])
292+
293+
deftest_constant_array_2(self):
294+
withcontextlib.redirect_stdout(io.StringIO()):
295+
g=GraphBuilder(verbose=10)
296+
297+
shape= (10,4)
298+
w=np.random.randn(*shape).astype(np.float32)
299+
300+
x=g.make_tensor_input("X",np.float32,shape)
301+
one=g.make_initializer(np.array([-1,1],dtype=np.int64))
302+
opc=g.op.Constant(value=from_array_extended(w.T))
303+
res=g.op.MatMul(x,opc)
304+
g.op.Reshape(res,one,outputs="y")
305+
g.make_tensor_output("y",np.float32, (10,1))
306+
self.assertTrue(g.has_shape("X"))
307+
self.assertTrue(g.has_type("X"))
308+
self.assertEqual(g.get_type("X"),1)
309+
self.assertEqual(g.get_shape("X"), (10,4))
310+
self.assertEqual(g.rank("X"),2)
311+
onx=g.to_onnx()
312+
ref=ReferenceEvaluator(onx)
313+
x=np.random.randn(*shape).astype(np.float32)
314+
expected= (x @w.T).reshape((-1,1))
315+
feeds= {"X":x}
316+
got=ref.run(None,feeds)
317+
self.assertEqualArray(expected,got[0])
318+
319+
deftest_get_type(self):
320+
g=GraphBuilder()
321+
self.assertEqual(g._get_type(np.float32),onnx.TensorProto.FLOAT)
322+
self.assertEqual(g._get_type(np.int64),onnx.TensorProto.INT64)
323+
self.assertEqual(g._get_type(None),onnx.TensorProto.UNDEFINED)
324+
325+
deftest_make_nodes_prefix(self):
326+
g1=GraphBuilder()
327+
g1.make_tensor_input("X",np.float32,shape=None)
328+
g1.op.Add("X",np.array([1],dtype=np.float32),outputs=["y"])
329+
g1.make_tensor_output("y",np.float32,shape=None)
330+
331+
g=GraphBuilder()
332+
333+
shape= (10,4)
334+
w=np.random.randn(*shape).astype(np.float32)
335+
336+
x=g.make_tensor_input("X",np.float32,shape)
337+
weight=g.make_initializer(w)
338+
one=g.make_initializer(np.array([-1,1],dtype=np.int64))
339+
transposed=g.make_node("Transpose", [weight],perm=[1,0])
340+
res=g.op.MatMul(x,transposed)
341+
res2=g.make_nodes(g1, [res], ["k"],prefix="J")
342+
g.op.Reshape(res2,one,outputs="y")
343+
g.make_tensor_output("y",np.float32, (10,1))
344+
onx=g.to_onnx()
345+
ref=ReferenceEvaluator(onx)
346+
x=np.random.randn(*shape).astype(np.float32)
347+
expected= (x @w.T).reshape((-1,1))+1
348+
feeds= {"X":x}
349+
got=ref.run(None,feeds)
350+
self.assertEqualArray(expected,got[0])
351+
352+
deftest_make_nodes_noprefix(self):
353+
g1=GraphBuilder()
354+
g1.make_tensor_input("X",np.float32,shape=None)
355+
g1.op.Add("X",np.array([1],dtype=np.float32),outputs=["y"])
356+
g1.make_tensor_output("y",np.float32,shape=None)
357+
358+
g=GraphBuilder()
359+
360+
shape= (10,4)
361+
w=np.random.randn(*shape).astype(np.float32)
362+
363+
x=g.make_tensor_input("X",np.float32,shape)
364+
weight=g.make_initializer(w)
365+
one=g.make_initializer(np.array([-1,1],dtype=np.int64))
366+
transposed=g.make_node("Transpose", [weight],perm=[1,0])
367+
res=g.op.MatMul(x,transposed)
368+
res2=g.make_nodes(g1, [res], ["k"])
369+
g.op.Reshape(res2,one,outputs="y")
370+
g.make_tensor_output("y",np.float32, (10,1))
371+
onx=g.to_onnx()
372+
ref=ReferenceEvaluator(onx)
373+
x=np.random.randn(*shape).astype(np.float32)
374+
expected= (x @w.T).reshape((-1,1))+1
375+
feeds= {"X":x}
376+
got=ref.run(None,feeds)
377+
self.assertEqualArray(expected,got[0])
378+
241379

242380
if__name__=="__main__":
243381
unittest.main(verbosity=2)

‎onnx_array_api/graph_api/graph_builder.py‎

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,18 @@
1818
T="TENSOR"
1919

2020

21+
classOptimizationOptions:
22+
def__init__(
23+
self,
24+
remove_unused:bool=True,
25+
constant_folding:bool=False,
26+
constant_size:int=1024,
27+
):
28+
self.remove_unused=remove_unused
29+
self.constant_folding=constant_folding
30+
self.constant_size=constant_size
31+
32+
2133
classOpset:
2234
# defined for opset >= 18
2335
# name: number of expected outputs
@@ -76,7 +88,7 @@ def make_node(
7688
foriininputs:
7789
ifnotisinstance(i,str):
7890
name=self.builder.unique_name("cst")
79-
self.builder.make_initializer(i,name=name)
91+
self.builder.make_initializer(i,name=name,exists=True)
8092
new_inputs.append(name)
8193
else:
8294
new_inputs.append(i)
@@ -86,18 +98,6 @@ def make_node(
8698
)
8799

88100

89-
classOptimizationOptions:
90-
def__init__(
91-
self,
92-
remove_unused:bool=True,
93-
constant_folding:bool=False,
94-
constant_size:int=1024,
95-
):
96-
self.remove_unused=remove_unused
97-
self.constant_folding=constant_folding
98-
self.constant_size=constant_size
99-
100-
101101
classGraphBuilder:
102102
def__init__(
103103
self,
@@ -304,12 +304,18 @@ def _get_type(self, elem_type: Any, exc: bool = True) -> int:
304304
returnelem_type
305305

306306
defmake_initializer(
307-
self,value:Any,name:str="",external:bool=False
307+
self,value:Any,name:str="",external:bool=False,exists:bool=False
308308
)->str:
309309
ifexternal:
310310
raiseNotImplementedError("External initializers are not implemented yet.")
311311
ifname=="":
312+
ifexists:
313+
raiseValueError("Undefined name cannot exist.")
312314
name=self.unique_name("cst")
315+
elifnotexists:
316+
ifnameinself._unique_names:
317+
raiseValueError(f"{name!r} is already assigned.")
318+
self._unique_names.add(name)
313319
self.set_shape(name,value.shape)
314320
self.set_type(name,self._get_type(value.dtype))
315321
self.initializers_dict[name]=value
@@ -330,6 +336,9 @@ def make_tensor_input(
330336
else:
331337
self.input_names.append(name)
332338
input_name=name
339+
ifnameinself._unique_names:
340+
raiseValueError(f"{name!r} is already assigned.")
341+
self._unique_names.add(name)
333342
self.current_input+=1
334343
elem_type=self._get_type(elem_type)
335344
self.inputs.append(oh.make_tensor_value_info(input_name,elem_type,shape))
@@ -397,15 +406,11 @@ def make_node(
397406
try:
398407
node=oh.make_node(op_type,inputs,output_names,domain=domain,**kwargs)
399408
exceptTypeErrorase:
400-
iti= [type(i)foriininputs]
401-
ito= (
402-
[type(o)foroinoutputs]
403-
ifisinstance(outputs, (tuple,list))
404-
elseoutputs
405-
)
406409
raiseTypeError(
407410
f"A node{op_type!r} cannot be created with "
408-
f"inputs={inputs} (types={iti}), outputs={outputs} (types={ito}), "
411+
f"inputs={inputs} (types={[type(i)foriininputs]}), "
412+
f"outputs={outputs} "
413+
f"(types={[type(o)foroinoutputs]ifisinstance(outputs, (tuple,list))elseoutputs}), "
409414
f"domain={domain!r}, kwargs={kwargs}."
410415
)frome
411416
ifattributes:
@@ -474,14 +479,18 @@ def make_nodes(
474479
self.set_shape(name,builder._known_shapes[init])
475480
self.set_type(name,builder._known_types[init])
476481

477-
assertlen(input_names)==len(
478-
builder.inputs
479-
),f"Inconsistency between input_names={input_names} and inputs={builder.inputs}."
482+
assertlen(input_names)==len(builder.inputs), (
483+
f"Inconsistency between input_names={input_names} "
484+
f"and the other builder inputs={builder.inputs}."
485+
)
486+
480487
forname,inpinzip(input_names,builder.inputs):
481488
new_name=self.unique_name(f"{prefix}{inp.name}")
482-
self.set_shape(new_name,builder.get_shape(inp.name))
483-
self.set_type(new_name,builder.get_type(inp.name))
484489
renaming[inp.name]=new_name
490+
ifbuilder.has_shape(inp.name):
491+
self.set_shape(new_name,builder.get_shape(inp.name))
492+
ifbuilder.has_type(inp.name):
493+
self.set_type(new_name,builder.get_type(inp.name))
485494
self.make_node("Identity", [name], [new_name])
486495

487496
fornodeinbuilder.nodes:

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp