66from onnx import AttributeProto ,FunctionProto ,ModelProto ,NodeProto ,TensorProto
77from onnx .reference import ReferenceEvaluator
88
9+ T = "TENSOR"
10+
911
1012class Opset :
1113# defined for opset >= 18
@@ -78,8 +80,8 @@ def make_node(
7880class OptimizationOptions :
7981def __init__ (
8082self ,
81- remove_unused :bool = False ,
82- constant_folding :bool = True ,
83+ remove_unused :bool = True ,
84+ constant_folding :bool = False ,
8385constant_size :int = 1024 ,
8486 ):
8587self .remove_unused = remove_unused
@@ -205,10 +207,6 @@ def get_constant(self, name: str) -> np.ndarray:
205207if isinstance (value ,np .ndarray ):
206208return value
207209
208- import torch
209-
210- if isinstance (value ,torch .Tensor ):
211- return value .detach ().numpy ()
212210raise TypeError (f"Unable to convert type{ type (value )} into numpy array." )
213211
214212def set_shape (self ,name :str ,shape :Tuple [int , ...]):
@@ -513,9 +511,7 @@ def make_nodes(
513511return output_names [0 ]
514512return output_names
515513
516- def from_array (
517- self ,arr :"torch.Tensor" ,name :str = None # noqa: F821
518- )-> TensorProto :
514+ def from_array (self ,arr :T ,name :str = None )-> TensorProto :# noqa: F821
519515import sys
520516import torch
521517
@@ -552,15 +548,8 @@ def from_array(
552548return tensor
553549
554550def _build_initializers (self )-> List [TensorProto ]:
555- import torch
556-
557551res = []
558552for k ,v in sorted (self .initializers_dict .items ()):
559- if isinstance (v ,torch .Tensor ):
560- # no string tensor
561- t = self .from_array (v ,name = k )
562- res .append (t )
563- continue
564553if isinstance (v ,np .ndarray ):
565554if self .verbose and np .prod (v .shape )> 100 :
566555print (f"[GraphBuilder] onh.from_array:{ k } :{ v .dtype } [{ v .shape } ]" )
@@ -575,7 +564,7 @@ def _build_initializers(self) -> List[TensorProto]:
575564
576565def process (
577566self ,
578- graph_module :"torch.f.GraphModule" , # noqa: F821
567+ graph_module :Any ,
579568interpreter :"Interpreter" ,# noqa: F821
580569 ):
581570for node in graph_module .graph .nodes :
@@ -656,19 +645,15 @@ def remove_unused(self):
656645self .constants_ = {k :v for k ,v in self .constants_ .items ()if k in marked }
657646self .nodes = [node for i ,node in enumerate (self .nodes )if i not in removed ]
658647
def _apply_transpose(
    self, node: "NodeProto", feeds: Dict[str, "np.ndarray"]  # noqa: F821
) -> List["np.ndarray"]:
    """Evaluate a ``Transpose`` node on already-computed inputs.

    :param node: a ``Transpose`` NodeProto; its ``perm`` attribute gives the
        axis permutation to apply
    :param feeds: maps input names to their concrete tensor values
        (numpy arrays — TODO confirm callers never feed torch tensors here)
    :return: a single-element list holding the transposed tensor, mirroring
        the one-output-per-node convention of the folding code
    """
    perm = None
    for att in node.attribute:
        if att.name == "perm":
            perm = tuple(att.ints)
            break
    assert perm, f"perm not here in node {node}"
    # np.transpose takes the permutation as ONE tuple argument; unpacking it
    # with *perm passes extra positionals and raises TypeError. numpy also
    # supports permutations of any rank, so the former torch-only
    # len(perm) == 2 restriction is no longer needed.
    return [np.transpose(feeds[node.input[0]], perm)]
673658def constant_folding (self ):
674659"""