Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit71aa3a0

Browse files
authored
Add methods to update nodes in GraphAPI (#59)
* Add methods to update nodes* update doc
1 parent6718ee8 commit71aa3a0

File tree

5 files changed

+188
-4
lines changed

5 files changed

+188
-4
lines changed

‎CHANGELOGS.rst‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.2.0
5+
+++++
6+
7+
*:pr:`59`: add methods to update nodes in GraphAPI
8+
49
0.1.3
510
+++++
611

‎_doc/api/graph_api.rst‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ GraphBuilder
99
..autoclass::onnx_array_api.graph_api.GraphBuilder
1010
:members:
1111

12+
NodePattern
13+
===========
14+
15+
..autoclass::onnx_array_api.graph_api.NodePattern
16+
:members:
17+
1218
OptimizationOptions
1319
===================
1420

‎_unittests/ut_graph_api/test_graph_builder.py‎

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,64 @@ def test_make_nodes_noprefix(self):
376376
got=ref.run(None,feeds)
377377
self.assertEqualArray(expected,got[0])
378378

379+
deftest_node_pattern(self):
380+
model=onnx.parser.parse_model(
381+
"""
382+
<ir_version: 8, opset_import: [ "": 18]>
383+
agraph (float[N] x) => (float[N] z) {
384+
two = Constant <value_float=2.0> ()
385+
four = Add(two, two)
386+
z = Mul(x, four)
387+
}"""
388+
)
389+
gr=GraphBuilder(model)
390+
p=gr.np(index=0)
391+
r=repr(p)
392+
self.assertEqual("NodePattern(index=0, op_type=None, name=None)",r)
393+
394+
deftest_update_node_attribute(self):
395+
model=onnx.parser.parse_model(
396+
"""
397+
<ir_version: 8, opset_import: [ "": 18]>
398+
agraph (float[N] x) => (float[N] z) {
399+
two = Constant <value_float=2.0> ()
400+
four = Add(two, two)
401+
z = Mul(x, four)
402+
}"""
403+
)
404+
gr=GraphBuilder(model)
405+
self.assertEqual(len(gr.nodes),3)
406+
m=gr.update_attribute(gr.np(op_type="Constant"),value_float=float(1))
407+
self.assertEqual(m,1)
408+
self.assertEqual(len(gr.nodes),3)
409+
onx=gr.to_onnx()
410+
self.assertEqual(len(onx.graph.node),3)
411+
node=onx.graph.node[0]
412+
self.assertIn("f: 1",str(node))
413+
414+
deftest_delete_node_attribute(self):
415+
model=onnx.parser.parse_model(
416+
"""
417+
<ir_version: 8, opset_import: [ "": 18]>
418+
agraph (float[N] x) => (float[N] z) {
419+
two = Constant <value_float=2.0> ()
420+
four = Add(two, two)
421+
z = Mul(x, four)
422+
}"""
423+
)
424+
gr=GraphBuilder(model)
425+
self.assertEqual(len(gr.nodes),3)
426+
m=gr.update_attribute(
427+
gr.np(op_type="Constant"),value_float=gr.DELETE,value_int=1
428+
)
429+
self.assertEqual(m,1)
430+
self.assertEqual(len(gr.nodes),3)
431+
onx=gr.to_onnx()
432+
self.assertEqual(len(onx.graph.node),3)
433+
node=onx.graph.node[0]
434+
self.assertNotIn('name: "value_float"',str(node))
435+
self.assertIn("i: 1",str(node))
436+
379437

380438
if__name__=="__main__":
381439
unittest.main(verbosity=2)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .graph_builderimportGraphBuilder
1+
from .graph_builderimportGraphBuilder,NodePattern

‎onnx_array_api/graph_api/graph_builder.py‎

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
importsys
22
fromfunctoolsimportpartial
3-
fromtypingimportAny,Dict,List,Optional,Sequence,Set,Tuple,Union
3+
fromtypingimportAny,Dict,Iterator,List,Optional,Sequence,Set,Tuple,Union
44
importnumpyasnp
55
fromonnx.defsimportonnx_opset_version
66
importonnx.helperasoh
@@ -30,6 +30,51 @@ def __init__(
3030
self.constant_size=constant_size
3131

3232

33+
classNodePattern:
34+
"""
35+
Class defining a matching pattern able to find nodes in a set of nodes.
36+
"""
37+
38+
def__init__(
39+
self,
40+
index:Optional[int]=None,
41+
op_type:Optional[str]=None,
42+
name:Optional[None]=None,
43+
):
44+
self.index=index
45+
self.op_type=op_type
46+
self.name=name
47+
48+
def__repr__(self):
49+
"usual"
50+
args= ["index","op_type","name"]
51+
sargs= []
52+
forainargs:
53+
ifa:
54+
sargs.append(f"{a}={getattr(self,a)!r}")
55+
returnf"{self.__class__.__name__}({', '.join(sargs)})"
56+
57+
deffind(self,graph:"GraphBuilder")->Iterator:
58+
"""
59+
Iterates on nodes matching the pattern.
60+
"""
61+
forindex,nodeinenumerate(graph.nodes):
62+
ifself.match(index,node):
63+
yieldnode
64+
65+
defmatch(self,index,node:NodeProto)->bool:
66+
"""
67+
Tells if a node is matching this pattern.
68+
"""
69+
ifself.indexisnotNoneandself.index!=index:
70+
returnFalse
71+
ifself.op_typeisnotNoneandself.op_type!=node.op_type:
72+
returnFalse
73+
ifself.nameisnotNoneandself.name!=node.name:
74+
returnFalse
75+
returnTrue
76+
77+
3378
classOpset:
3479
# defined for opset >= 18
3580
# name: number of expected outputs
@@ -168,7 +213,7 @@ def __init__(
168213
f"{type(target_opset_or_existing_proto)} is not supported."
169214
)
170215

171-
self.op=Opset(self,self.opsets[""])
216+
self.op=Opset(self,self.opsets[""])if""inself.opsetselseNone
172217
self._cache_array= []
173218

174219
def_get_tensor_shape(
@@ -749,7 +794,6 @@ def constant_folding(self):
749794
Folds all constants. Constants are marked during the creation of the graph.
750795
There is no need to propagate this information.
751796
"""
752-
753797
updates= {}
754798
node_to_remove=set()
755799
fork,vinself.constants_.items():
@@ -840,3 +884,74 @@ def remove_identity_nodes(self):
840884
self.nodes.append(new_node)
841885
else:
842886
self.nodes.append(node)
887+
888+
defnp(
889+
self,
890+
index:Optional[int]=None,
891+
op_type:Optional[str]=None,
892+
name:Optional[str]=None,
893+
)->NodePattern:
894+
"""
895+
Returns an instance of :class:`NodePattern
896+
<onnx_array_api.graph_api.graph_builder.NodePattern>`.
897+
"""
898+
returnNodePattern(index=index,op_type=op_type,name=name)
899+
900+
defupdate_attribute(
901+
self,
902+
pat:NodePattern,
903+
recursive:bool=False,
904+
**kwargs:Dict[str,Any],
905+
)->int:
906+
"""
907+
Udates attributes for nodes matching the
908+
909+
:param pat: returned by method :meth:`GraphBuilder.np`
910+
:param recursive: walk through subgraph
911+
:param kwargs: attributes to modify
912+
:return: number of modified nodes
913+
"""
914+
assertnotrecursive,"recursive=True is not implemented."
915+
modified=0
916+
fornodeinpat.find(self):
917+
up=self.update_node(node,**kwargs)
918+
ifup:
919+
modified+=1
920+
returnmodified
921+
922+
DELETE=object()
923+
924+
defupdate_node(self,node:NodeProto,**kwargs)->bool:
925+
"""
926+
Updates attributes of a node proto.
927+
Returns True if the node was updated.
928+
"""
929+
processed=set()
930+
modified=True
931+
atts= []
932+
forattinnode.attribute:
933+
ifatt.nameinkwargs:
934+
processed.add(att.name)
935+
ifkwargs[att.name]isGraphBuilder.DELETE:
936+
continue
937+
new_att=oh.make_attribute(att.name,kwargs[att.name])
938+
assertnew_att.type==att.type, (
939+
f"Mismatch value for attribute{att.name!r} has type "
940+
f"{att.type} but the new value leads to "
941+
f"type={new_att.type}."
942+
)
943+
atts.append(new_att)
944+
modified=True
945+
continue
946+
atts.append(att)
947+
fork,vinkwargs.items():
948+
ifkinprocessedorvisGraphBuilder.DELETE:
949+
continue
950+
modified=True
951+
new_att=oh.make_attribute(k,v)
952+
atts.append(new_att)
953+
954+
ifmodified:
955+
delnode.attribute[:]
956+
node.attribute.extend(atts)
957+
returnmodified

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp