Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit7330b58

Browse files
committed
remove some torch issues
1 parent395e281 commit7330b58

File tree

1 file changed

+8
-23
lines changed

1 file changed

+8
-23
lines changed

‎onnx_array_api/graph_api/graph_builder.py‎

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
fromonnximportAttributeProto,FunctionProto,ModelProto,NodeProto,TensorProto
77
fromonnx.referenceimportReferenceEvaluator
88

9+
T="TENSOR"
10+
911

1012
classOpset:
1113
# defined for opset >= 18
@@ -78,8 +80,8 @@ def make_node(
7880
classOptimizationOptions:
7981
def__init__(
8082
self,
81-
remove_unused:bool=False,
82-
constant_folding:bool=True,
83+
remove_unused:bool=True,
84+
constant_folding:bool=False,
8385
constant_size:int=1024,
8486
):
8587
self.remove_unused=remove_unused
@@ -205,10 +207,6 @@ def get_constant(self, name: str) -> np.ndarray:
205207
ifisinstance(value,np.ndarray):
206208
returnvalue
207209

208-
importtorch
209-
210-
ifisinstance(value,torch.Tensor):
211-
returnvalue.detach().numpy()
212210
raiseTypeError(f"Unable to convert type{type(value)} into numpy array.")
213211

214212
defset_shape(self,name:str,shape:Tuple[int, ...]):
@@ -513,9 +511,7 @@ def make_nodes(
513511
returnoutput_names[0]
514512
returnoutput_names
515513

516-
deffrom_array(
517-
self,arr:"torch.Tensor",name:str=None# noqa: F821
518-
)->TensorProto:
514+
deffrom_array(self,arr:T,name:str=None)->TensorProto:# noqa: F821
519515
importsys
520516
importtorch
521517

@@ -552,15 +548,8 @@ def from_array(
552548
returntensor
553549

554550
def_build_initializers(self)->List[TensorProto]:
555-
importtorch
556-
557551
res= []
558552
fork,vinsorted(self.initializers_dict.items()):
559-
ifisinstance(v,torch.Tensor):
560-
# no string tensor
561-
t=self.from_array(v,name=k)
562-
res.append(t)
563-
continue
564553
ifisinstance(v,np.ndarray):
565554
ifself.verboseandnp.prod(v.shape)>100:
566555
print(f"[GraphBuilder] onh.from_array:{k}:{v.dtype}[{v.shape}]")
@@ -575,7 +564,7 @@ def _build_initializers(self) -> List[TensorProto]:
575564

576565
defprocess(
577566
self,
578-
graph_module:"torch.f.GraphModule",# noqa: F821
567+
graph_module:Any,
579568
interpreter:"Interpreter",# noqa: F821
580569
):
581570
fornodeingraph_module.graph.nodes:
@@ -656,19 +645,15 @@ def remove_unused(self):
656645
self.constants_= {k:vfork,vinself.constants_.items()ifkinmarked}
657646
self.nodes= [nodefori,nodeinenumerate(self.nodes)ifinotinremoved]
658647

659-
def_apply_transpose(
660-
self,node:NodeProto,feeds:Dict[str,"torch.Tensor"]# noqa: F821
661-
)->"torch.Tensor":# noqa: F821
662-
importtorch
663-
648+
def_apply_transpose(self,node:NodeProto,feeds:Dict[str,T])->T:# noqa: F821
664649
perm=None
665650
forattinnode.attribute:
666651
ifatt.name=="perm":
667652
perm=tuple(att.ints)
668653
break
669654
assertperm,f"perm not here in node{node}"
670655
assertlen(perm)==2,f"perm={perm} is not supported with torch"
671-
return [torch.transpose(feeds[node.input[0]],*perm)]
656+
return [np.transpose(feeds[node.input[0]],*perm)]
672657

673658
defconstant_folding(self):
674659
"""

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp