Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit6532733

Browse files
committed
add method check_order
1 parentb73a0cb commit6532733

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

‎_unittests/ut_graph_api/test_graph_builder_optim.py‎

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
importos
22
importunittest
33
importonnx
4+
fromonnx.inlinerimportinline_local_functions
45
fromonnx_array_api.ext_test_caseimportExtTestCase
56
fromonnx_array_api.graph_api.graph_builderimportGraphBuilder
67

@@ -54,7 +55,7 @@ def test_keep_unused_outputs(self):
5455
self.assertEqual(len(onx.graph.node),2)
5556
self.assertEqual(onx.graph.node[0].op_type,"Split")
5657

57-
deftest_check_files(self):
58+
deftest_check_afiles(self):
5859
importonnxruntime
5960

6061
data=os.path.join(os.path.dirname(__file__),"data")
@@ -66,8 +67,14 @@ def test_check_files(self):
6667
os.path.join(data,f),providers=["CPUExecutionProvider"]
6768
)
6869
assertsess
69-
g=GraphBuilder(onx)
70-
g.optimize()
70+
onxi=inline_local_functions(onx)
71+
sess=onnxruntime.InferenceSession(
72+
onxi.SerializeToString(),providers=["CPUExecutionProvider"]
73+
)
74+
assertsess
75+
g=GraphBuilder(onxi)
76+
g.optimize(check_order=True)
77+
g.check_order()
7178
onx2=g.to_onnx()
7279
sess2=onnxruntime.InferenceSession(
7380
onx2.SerializeToString(),providers=["CPUExecutionProvider"]

‎onnx_array_api/graph_api/graph_builder.py‎

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
fromfunctoolsimportpartial
2-
fromtypingimportAny,Dict,List,Optional,Sequence,Tuple,Union
2+
fromtypingimportAny,Dict,List,Optional,Sequence,Set,Tuple,Union
33
importnumpyasnp
44
importonnx.helperasoh
55
importonnx.numpy_helperasonh
@@ -604,14 +604,56 @@ def to_onnx(
604604
model=oh.make_model(graph,opset_imports=opsets)
605605
returnmodel
606606

607-
defoptimize(self):
607+
def_check_order_node(self,ind:int,node:NodeProto,existing:Set[str]):
608+
foriinnode.input:
609+
ifinotinexisting:
610+
raiseRuntimeError(
611+
f"Unknown input{i!r} from node{ind}:{node.op_type}:{node.name}. "
612+
f"Known:{existing}."
613+
)
614+
forattinnode.attribute:
615+
ifatt.type==AttributeProto.GRAPHandatt.g:
616+
g_existing=existing.copy()
617+
foriinatt.g.input:
618+
g_existing.add(i.name)
619+
forind2,node2inenumerate(att.g.node):
620+
self._check_order_node((ind,ind2),node2,g_existing)
621+
foroinatt.g.output:
622+
ifo.namenoting_existing:
623+
raiseRuntimeError(
624+
f"Unknown output{o.name!r}. Known:{g_existing}."
625+
)
626+
foroinnode.output:
627+
existing.add(o)
628+
629+
defcheck_order(self):
630+
existing=set(self.initializers_dict)
631+
foriinself.inputs:
632+
existing.add(i.name)
633+
forind,nodeinenumerate(self.nodes):
634+
self._check_order_node(ind,node,existing)
635+
foroinself.outputs:
636+
ifo.namenotinexisting:
637+
raiseRuntimeError(f"Unknown output{o.name!r}. Known:{existing}.")
638+
639+
defoptimize(self,check_order:bool=False):
640+
ifcheck_order:
641+
self.check_order()
608642
self.remove_identity_nodes()
643+
ifcheck_order:
644+
self.check_order()
609645
ifself.optimization_options.remove_unused:
610646
self.remove_unused()
647+
ifcheck_order:
648+
self.check_order()
611649
ifself.optimization_options.constant_folding:
612650
self.constant_folding()
651+
ifcheck_order:
652+
self.check_order()
613653
ifself.optimization_options.remove_unused:
614654
self.remove_unused()
655+
ifcheck_order:
656+
self.check_order()
615657

616658
defremove_unused(self):
617659
"""

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp