|
1 | 1 | importsys |
2 | 2 | fromfunctoolsimportpartial |
3 | | -fromtypingimportAny,Dict,List,Optional,Sequence,Set,Tuple,Union |
| 3 | +fromtypingimportAny,Dict,Iterator,List,Optional,Sequence,Set,Tuple,Union |
4 | 4 | importnumpyasnp |
5 | 5 | fromonnx.defsimportonnx_opset_version |
6 | 6 | importonnx.helperasoh |
@@ -30,6 +30,51 @@ def __init__( |
30 | 30 | self.constant_size=constant_size |
31 | 31 |
|
32 | 32 |
|
| 33 | +classNodePattern: |
| 34 | +""" |
| 35 | + Class defining a matching pattern able to find nodes in a set of nodes. |
| 36 | + """ |
| 37 | + |
| 38 | +def__init__( |
| 39 | +self, |
| 40 | +index:Optional[int]=None, |
| 41 | +op_type:Optional[str]=None, |
| 42 | +name:Optional[None]=None, |
| 43 | + ): |
| 44 | +self.index=index |
| 45 | +self.op_type=op_type |
| 46 | +self.name=name |
| 47 | + |
| 48 | +def__repr__(self): |
| 49 | +"usual" |
| 50 | +args= ["index","op_type","name"] |
| 51 | +sargs= [] |
| 52 | +forainargs: |
| 53 | +ifa: |
| 54 | +sargs.append(f"{a}={getattr(self,a)!r}") |
| 55 | +returnf"{self.__class__.__name__}({', '.join(sargs)})" |
| 56 | + |
| 57 | +deffind(self,graph:"GraphBuilder")->Iterator: |
| 58 | +""" |
| 59 | + Iterates on nodes matching the pattern. |
| 60 | + """ |
| 61 | +forindex,nodeinenumerate(graph.nodes): |
| 62 | +ifself.match(index,node): |
| 63 | +yieldnode |
| 64 | + |
| 65 | +defmatch(self,index,node:NodeProto)->bool: |
| 66 | +""" |
| 67 | + Tells if a node is matching this pattern. |
| 68 | + """ |
| 69 | +ifself.indexisnotNoneandself.index!=index: |
| 70 | +returnFalse |
| 71 | +ifself.op_typeisnotNoneandself.op_type!=node.op_type: |
| 72 | +returnFalse |
| 73 | +ifself.nameisnotNoneandself.name!=node.name: |
| 74 | +returnFalse |
| 75 | +returnTrue |
| 76 | + |
| 77 | + |
33 | 78 | classOpset: |
34 | 79 | # defined for opset >= 18 |
35 | 80 | # name: number of expected outputs |
@@ -749,7 +794,6 @@ def constant_folding(self): |
749 | 794 | Folds all constants. Constants are marked during the creation of the graph. |
750 | 795 | There is no need to propagate this information. |
751 | 796 | """ |
752 | | - |
753 | 797 | updates= {} |
754 | 798 | node_to_remove=set() |
755 | 799 | fork,vinself.constants_.items(): |
@@ -840,3 +884,71 @@ def remove_identity_nodes(self): |
840 | 884 | self.nodes.append(new_node) |
841 | 885 | else: |
842 | 886 | self.nodes.append(node) |
| 887 | + |
| 888 | +defnp( |
| 889 | +self, |
| 890 | +index:Optional[int]=None, |
| 891 | +op_type:Optional[str]=None, |
| 892 | +name:Optional[str]=None, |
| 893 | + )->NodePattern: |
| 894 | +"Returns an instance of :class:`NodePattern`." |
| 895 | +returnNodePattern(index=index,op_type=op_type,name=name) |
| 896 | + |
| 897 | +defupdate_attribute( |
| 898 | +self, |
| 899 | +pat:NodePattern, |
| 900 | +recursive:bool=False, |
| 901 | +**kwargs:Dict[str,Any], |
| 902 | + )->int: |
| 903 | +""" |
| 904 | + Udates attributes for nodes matching the |
| 905 | +
|
| 906 | + :param pat: returned by method :meth:`GraphBuilder.np` |
| 907 | + :param recursive: walk through subgraph |
| 908 | + :param kwargs: attributes to modify |
| 909 | + :return: number of modified nodes |
| 910 | + """ |
| 911 | +assertnotrecursive,"recursive=True is not implemented." |
| 912 | +modified=0 |
| 913 | +fornodeinpat.find(self): |
| 914 | +up=self.update_node(node,**kwargs) |
| 915 | +ifup: |
| 916 | +modified+=1 |
| 917 | +returnmodified |
| 918 | + |
| 919 | +DELETE=object() |
| 920 | + |
| 921 | +defupdate_node(self,node:NodeProto,**kwargs)->bool: |
| 922 | +""" |
| 923 | + Updates attributes of a node proto. |
| 924 | + Returns True if the node was updated. |
| 925 | + """ |
| 926 | +processed=set() |
| 927 | +modified=True |
| 928 | +atts= [] |
| 929 | +forattinnode.attribute: |
| 930 | +ifatt.nameinkwargs: |
| 931 | +processed.add(att.name) |
| 932 | +ifkwargs[att.name]isGraphBuilder.DELETE: |
| 933 | +continue |
| 934 | +new_att=oh.make_attribute(att.name,kwargs[att.name]) |
| 935 | +assertnew_att.type==att.type, ( |
| 936 | +f"Mismatch value for attribute{att.name!r} has type " |
| 937 | +f"{att.type} but the new value leads to " |
| 938 | +f"type={new_att.type}." |
| 939 | + ) |
| 940 | +atts.append(new_att) |
| 941 | +modified=True |
| 942 | +continue |
| 943 | +atts.append(att) |
| 944 | +fork,vinkwargs.items(): |
| 945 | +ifkinprocessedorvisGraphBuilder.DELETE: |
| 946 | +continue |
| 947 | +modified=True |
| 948 | +new_att=oh.make_attribute(k,v) |
| 949 | +atts.append(new_att) |
| 950 | + |
| 951 | +ifmodified: |
| 952 | +delnode.attribute[:] |
| 953 | +node.attribute.extend(atts) |
| 954 | +returnmodified |