33import numpy as np
44import onnx .helper as oh
55import onnx .numpy_helper as onh
6- from onnx import AttributeProto ,FunctionProto ,ModelProto ,NodeProto ,TensorProto
6+ from onnx import (
7+ AttributeProto ,
8+ FunctionProto ,
9+ GraphProto ,
10+ ModelProto ,
11+ NodeProto ,
12+ TensorProto ,
13+ )
714from onnx .reference import ReferenceEvaluator
815
916T = "TENSOR"
@@ -655,6 +662,22 @@ def optimize(self, check_order: bool = False):
655662if check_order :
656663self .check_order ()
657664
665+ def hidden_inputs_graph (self ,graph :GraphProto )-> Set [str ]:
666+ hidden = set ()
667+ memo = set (i .name for i in graph .initializer )
668+ memo |= set (i .name for i in graph .sparse_initializer )
669+ for node in graph .node :
670+ for i in node .input :
671+ if i not in memo :
672+ hidden .add (i )
673+ for att in node .attribute :
674+ if att .type == AttributeProto .GRAPH and att .g :
675+ hid = self .hidden_inputs_graph (att .g )
676+ less = set (h for h in hid if h not in memo )
677+ hidden |= less
678+ memo |= set (node .output )
679+ return hidden
680+
658681def remove_unused (self ):
659682"""
660683 Simple function to remove unused nodes.
@@ -671,6 +694,11 @@ def remove_unused(self):
671694for i in node .input :
672695marked [o ].add (i )
673696used = True
697+ for att in node .attribute :
698+ if att .type == AttributeProto .GRAPH and att .g :
699+ hidden_inputs = self .hidden_inputs_graph (att .g )
700+ for i in hidden_inputs :
701+ marked [i ]= set ()
674702if used :
675703for i in node .input :
676704marked [i ]= set ()