@@ -836,30 +836,57 @@ def remove_identity_nodes(self):
836836"""
837837 Removes identity nodes.
838838 """
839- #f<irst pass: detect replacements
839+ #first pass: detect replacements
840840new_nodes = []
841841input_names = set (i .name for i in self .inputs )
842842output_names = set (i .name for i in self .outputs )
843843replacements = {}
844+ replacements_rev = {}
844845for node in self .nodes :
845846if node .op_type != "Identity" :
846847new_nodes .append (node )
847848continue
848849
849850if node .output [0 ]not in output_names :
850851old_name ,new_name = node .output [0 ],node .input [0 ]
851- elif node .input [0 ]not in input_names :
852+ elif (
853+ node .input [0 ]not in input_names
854+ and node .input [0 ]not in output_names
855+ and node .input [0 ]not in replacements
856+ ):
852857old_name ,new_name = node .input [0 ],node .output [0 ]
853858else :
854859new_nodes .append (node )
855860continue
856861
857862# the new name can be set for replacements as well
858- assert old_name not in replacements
859863if new_name in replacements :
860864new_name = replacements [new_name ]
861- assert new_name not in replacements
865+ assert new_name not in replacements , (
866+ f"Name{ old_name !r} still in{ replacements } , node.op_type={ node .op_type !r} , "
867+ f"node.input={ node .input } , node.output={ node .output } , "
868+ f"input_names={ input_names } , output_names={ output_names } "
869+ )
870+ if old_name in replacements_rev :
871+ old_old_name = replacements_rev [old_name ]
872+ replacements [old_old_name ]= new_name
873+ replacements_rev [new_name ]= old_old_name
874+ if old_name in replacements :
875+ replacements [replacements [old_name ]]= new_name
876+ assert new_name not in replacements , (
877+ f"Name{ old_name !r} still in{ replacements } , node.op_type={ node .op_type !r} , "
878+ f"node.input={ node .input } , node.output={ node .output } , "
879+ f"input_names={ input_names } , output_names={ output_names } "
880+ )
862881replacements [old_name ]= new_name
882+ replacements_rev [new_name ]= old_name
883+
884+ # verification
885+ for k ,v in replacements .items ():
886+ assert v not in replacements , (
887+ f"replacement{ k } ->{ v } is not possible because of "
888+ f"{ v } ->{ replacements [v ]} , old_name={ old_name !r} , new_name={ new_name !r} "
889+ )
863890
864891# second pass: replacements in initializer
865892for k ,v in replacements .items ():
@@ -876,10 +903,12 @@ def remove_identity_nodes(self):
876903repo = {o for o in node .output if o in replacements }
877904repi = {o for o in node .input if o in replacements }
878905if repi or repo :
906+ new_inputs = [replacements .get (i ,i )for i in node .input ]
907+ new_outputs = [replacements .get (i ,i )for i in node .output ]
879908new_node = oh .make_node (
880909node .op_type ,
881- [ replacements . get ( i , i ) for i in node . input ] ,
882- [ replacements . get ( i , i ) for i in node . output ] ,
910+ new_inputs ,
911+ new_outputs ,
883912domain = node .domain ,
884913name = node .name ,
885914 )