Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit9907a3e

Browse files
lara-hdrfacebook-github-bot
authored andcommitted
Update Argmin/Argmax ONNX Export (#38329)
Summary:Update Argmin/Argmax ONNX export in opset 12 to export with "select_last_index", and export correctly cases where the same value appears multiple time in the input tensor.Pull Requestresolved:#38329Reviewed By: hl475Differential Revision: D21613799Pulled By: houseroadfbshipit-source-id: 4597e23561f444c4e56d30c735dae7e9a8a41c5e
1 parentcbd0adc commit9907a3e

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

‎test/onnx/test_pytorch_onnx_onnxruntime.py‎

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2045,6 +2045,32 @@ def forward(self, input, other):
20452045
y=torch.randint(10, (2,4,5))
20462046
self.run_test(MatmulModel(), (x,y))
20472047

2048+
def_argmin_argmax_model(self,input):
2049+
classArgminArgmaxModel(torch.nn.Module):
2050+
defforward(self,input):
2051+
returntorch.argmin(input), \
2052+
torch.argmax(input), \
2053+
torch.argmin(input,keepdim=True), \
2054+
torch.argmax(input,keepdim=True)
2055+
2056+
self.run_test(ArgminArgmaxModel(),input)
2057+
2058+
deftest_argmin_argmax(self):
2059+
input=torch.randn(7,3,5)
2060+
self._argmin_argmax_model(input)
2061+
2062+
# Argmin and Argmax with "select_last_index" is not supprted before opset 12
2063+
# "select_last_index" was added in opset 12 to deal with corner case where the
2064+
# same value appears multiple times in the tensor
2065+
@skipIfUnsupportedMinOpsetVersion(12)
2066+
deftest_argmin_argmax_select_last_index(self):
2067+
input=torch.tensor([[1.,2.,3.],
2068+
[1.,1.,2.]])
2069+
self._argmin_argmax_model(input)
2070+
2071+
input=torch.ones(7,3,5)
2072+
self._argmin_argmax_model(input)
2073+
20482074
deftest_view(self):
20492075
classViewModel(torch.nn.Module):
20502076
defforward(self,input):

‎torch/onnx/symbolic_opset12.py‎

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
importtorch
44
importtorch.onnx.symbolic_helperassym_help
5-
fromtorch.onnx.symbolic_helperimportparse_args
5+
fromtorch.onnx.symbolic_helperimportparse_args,_parse_arg
66

77

88
# EDITING THIS FILE? READ THIS FIRST!
@@ -62,6 +62,28 @@ def nll_loss2d(g, self, target, weight, reduction, ignore_index):
6262
returnnll_loss(g,self,target,weight,reduction,ignore_index)
6363

6464

65+
defargmax(g,input,dim,keepdim):
66+
ifsym_help._is_none(dim):
67+
fromtorch.onnx.symbolic_opset9importreshape
68+
flattened=reshape(g,input, (-1,))
69+
returng.op('ArgMax',flattened,axis_i=0,keepdims_i=False,select_last_index_i=True)
70+
else:
71+
dim=_parse_arg(dim,'i')
72+
keepdim=_parse_arg(keepdim,'i')
73+
returng.op('ArgMax',input,axis_i=dim,keepdims_i=keepdim,select_last_index_i=True)
74+
75+
76+
defargmin(g,input,dim,keepdim):
77+
ifsym_help._is_none(dim):
78+
fromtorch.onnx.symbolic_opset9importreshape
79+
flattened=reshape(g,input, (-1,))
80+
returng.op('ArgMin',flattened,axis_i=0,keepdims_i=False,select_last_index_i=True)
81+
else:
82+
dim=_parse_arg(dim,'i')
83+
keepdim=_parse_arg(keepdim,'i')
84+
returng.op('ArgMin',input,axis_i=dim,keepdims_i=keepdim,select_last_index_i=True)
85+
86+
6587
defpow(g,self,exponent):
6688
returng.op("Pow",self,exponent)
6789

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp