Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit68eb05e

Browse files
committed
Add methods to update nodes
1 parent6718ee8 commit68eb05e

File tree

3 files changed

+178
-2
lines changed

3 files changed

+178
-2
lines changed

‎_doc/api/graph_api.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ GraphBuilder
99
..autoclass::onnx_array_api.graph_api.GraphBuilder
1010
:members:
1111

12+
NodePattern
13+
===========
14+
15+
..autoclass::onnx_array_api.graph_api.NodePattern
16+
:members:
17+
1218
OptimizationOptions
1319
===================
1420

‎_unittests/ut_graph_api/test_graph_builder.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,64 @@ def test_make_nodes_noprefix(self):
376376
got=ref.run(None,feeds)
377377
self.assertEqualArray(expected,got[0])
378378

379+
deftest_node_pattern(self):
380+
model=onnx.parser.parse_model(
381+
"""
382+
<ir_version: 8, opset_import: [ "": 18]>
383+
agraph (float[N] x) => (float[N] z) {
384+
two = Constant <value_float=2.0> ()
385+
four = Add(two, two)
386+
z = Mul(x, four)
387+
}"""
388+
)
389+
gr=GraphBuilder(model)
390+
p=gr.np(index=0)
391+
r=repr(p)
392+
self.assertEqual("NodePattern(index=0, op_type=None, name=None)",r)
393+
394+
deftest_update_node_attribute(self):
395+
model=onnx.parser.parse_model(
396+
"""
397+
<ir_version: 8, opset_import: [ "": 18]>
398+
agraph (float[N] x) => (float[N] z) {
399+
two = Constant <value_float=2.0> ()
400+
four = Add(two, two)
401+
z = Mul(x, four)
402+
}"""
403+
)
404+
gr=GraphBuilder(model)
405+
self.assertEqual(len(gr.nodes),3)
406+
m=gr.update_attribute(gr.np(op_type="Constant"),value_float=float(1))
407+
self.assertEqual(m,1)
408+
self.assertEqual(len(gr.nodes),3)
409+
onx=gr.to_onnx()
410+
self.assertEqual(len(onx.graph.node),3)
411+
node=onx.graph.node[0]
412+
self.assertIn("f: 1",str(node))
413+
414+
deftest_delete_node_attribute(self):
415+
model=onnx.parser.parse_model(
416+
"""
417+
<ir_version: 8, opset_import: [ "": 18]>
418+
agraph (float[N] x) => (float[N] z) {
419+
two = Constant <value_float=2.0> ()
420+
four = Add(two, two)
421+
z = Mul(x, four)
422+
}"""
423+
)
424+
gr=GraphBuilder(model)
425+
self.assertEqual(len(gr.nodes),3)
426+
m=gr.update_attribute(
427+
gr.np(op_type="Constant"),value_float=gr.DELETE,value_int=1
428+
)
429+
self.assertEqual(m,1)
430+
self.assertEqual(len(gr.nodes),3)
431+
onx=gr.to_onnx()
432+
self.assertEqual(len(onx.graph.node),3)
433+
node=onx.graph.node[0]
434+
self.assertNotIn('name: "value_float"',str(node))
435+
self.assertIn("i: 1",str(node))
436+
379437

380438
if__name__=="__main__":
381439
unittest.main(verbosity=2)

‎onnx_array_api/graph_api/graph_builder.py

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
importsys
22
fromfunctoolsimportpartial
3-
fromtypingimportAny,Dict,List,Optional,Sequence,Set,Tuple,Union
3+
fromtypingimportAny,Dict,Iterator,List,Optional,Sequence,Set,Tuple,Union
44
importnumpyasnp
55
fromonnx.defsimportonnx_opset_version
66
importonnx.helperasoh
@@ -30,6 +30,51 @@ def __init__(
3030
self.constant_size=constant_size
3131

3232

33+
classNodePattern:
34+
"""
35+
Class defining a matching pattern able to find nodes in a set of nodes.
36+
"""
37+
38+
def__init__(
39+
self,
40+
index:Optional[int]=None,
41+
op_type:Optional[str]=None,
42+
name:Optional[None]=None,
43+
):
44+
self.index=index
45+
self.op_type=op_type
46+
self.name=name
47+
48+
def__repr__(self):
49+
"usual"
50+
args= ["index","op_type","name"]
51+
sargs= []
52+
forainargs:
53+
ifa:
54+
sargs.append(f"{a}={getattr(self,a)!r}")
55+
returnf"{self.__class__.__name__}({', '.join(sargs)})"
56+
57+
deffind(self,graph:"GraphBuilder")->Iterator:
58+
"""
59+
Iterates on nodes matching the pattern.
60+
"""
61+
forindex,nodeinenumerate(graph.nodes):
62+
ifself.match(index,node):
63+
yieldnode
64+
65+
defmatch(self,index,node:NodeProto)->bool:
66+
"""
67+
Tells if a node is matching this pattern.
68+
"""
69+
ifself.indexisnotNoneandself.index!=index:
70+
returnFalse
71+
ifself.op_typeisnotNoneandself.op_type!=node.op_type:
72+
returnFalse
73+
ifself.nameisnotNoneandself.name!=node.name:
74+
returnFalse
75+
returnTrue
76+
77+
3378
classOpset:
3479
# defined for opset >= 18
3580
# name: number of expected outputs
@@ -749,7 +794,6 @@ def constant_folding(self):
749794
Folds all constants. Constants are marked during the creation of the graph.
750795
There is no need to propagate this information.
751796
"""
752-
753797
updates= {}
754798
node_to_remove=set()
755799
fork,vinself.constants_.items():
@@ -840,3 +884,71 @@ def remove_identity_nodes(self):
840884
self.nodes.append(new_node)
841885
else:
842886
self.nodes.append(node)
887+
888+
defnp(
889+
self,
890+
index:Optional[int]=None,
891+
op_type:Optional[str]=None,
892+
name:Optional[str]=None,
893+
)->NodePattern:
894+
"Returns an instance of :class:`NodePattern`."
895+
returnNodePattern(index=index,op_type=op_type,name=name)
896+
897+
defupdate_attribute(
898+
self,
899+
pat:NodePattern,
900+
recursive:bool=False,
901+
**kwargs:Dict[str,Any],
902+
)->int:
903+
"""
904+
Udates attributes for nodes matching the
905+
906+
:param pat: returned by method :meth:`GraphBuilder.np`
907+
:param recursive: walk through subgraph
908+
:param kwargs: attributes to modify
909+
:return: number of modified nodes
910+
"""
911+
assertnotrecursive,"recursive=True is not implemented."
912+
modified=0
913+
fornodeinpat.find(self):
914+
up=self.update_node(node,**kwargs)
915+
ifup:
916+
modified+=1
917+
returnmodified
918+
919+
DELETE=object()
920+
921+
defupdate_node(self,node:NodeProto,**kwargs)->bool:
922+
"""
923+
Updates attributes of a node proto.
924+
Returns True if the node was updated.
925+
"""
926+
processed=set()
927+
modified=True
928+
atts= []
929+
forattinnode.attribute:
930+
ifatt.nameinkwargs:
931+
processed.add(att.name)
932+
ifkwargs[att.name]isGraphBuilder.DELETE:
933+
continue
934+
new_att=oh.make_attribute(att.name,kwargs[att.name])
935+
assertnew_att.type==att.type, (
936+
f"Mismatch value for attribute{att.name!r} has type "
937+
f"{att.type} but the new value leads to "
938+
f"type={new_att.type}."
939+
)
940+
atts.append(new_att)
941+
modified=True
942+
continue
943+
atts.append(att)
944+
fork,vinkwargs.items():
945+
ifkinprocessedorvisGraphBuilder.DELETE:
946+
continue
947+
modified=True
948+
new_att=oh.make_attribute(k,v)
949+
atts.append(new_att)
950+
951+
ifmodified:
952+
delnode.attribute[:]
953+
node.attribute.extend(atts)
954+
returnmodified

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp