Need help with conversion to ONNX #211

Open
@surajs52


Hey @foolwood,
I need help converting the model to ONNX format.
My Python script, which uses torch.onnx.export() for the conversion, looks like this:

```python
from tools.test import *
#from siammask.models import Custom
from custom import Custom

parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')

parser.add_argument('--resume', default='', type=str, required=True,
                    metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--config', dest='config', default='config_davis.json',
                    help='hyper-parameter of SiamMask in json format')
#parser.add_argument('--base_path', default='../../data/tennis', help='datasets')
#parser.add_argument('--cpu', action='store_true', help='cpu mode')
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

cfg = load_config(args)
siammask = Custom(anchors=cfg['anchors'])

siammask.load_state_dict(torch.load('SiamMask_DAVIS.pth')["state_dict"])

siammask.eval().to(device)
siammask.half()

template = torch.randn(1, 3, 127, 127).to(device).half()
search = torch.randn(1, 3, 255, 255).to(device).half()
label_cls = torch.randn(1, 1, 5).to(device).half()
input_dict = {'template': template, 'search': search} #, 'label_cls': label_cls}

torch.onnx.export(siammask, input_dict, "SiamMask_DAVIS_half_test.onnx",
                  input_names=['template', 'search'],
                  opset_version=11,
                  do_constant_folding=True,
                  verbose=True,
                  output_names=['rpn_pred_cls', 'rpn_pred_loc', 'pred_mask'],
                  dynamic_axes={'search': {0: 'batch_size'},  # if you want batch size to be dynamic
                                'rpn_pred_cls': {0: 'batch_size'},
                                'rpn_pred_loc': {0: 'batch_size'},
                                'pred_mask': {0: 'batch_size'}})
```
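Independently of the traceback, two details in this script may be worth checking: the checkpoint path is hard-coded to `SiamMask_DAVIS.pth` instead of being read from `--resume`, and `.half()` is applied even when `device` falls back to CPU, where many PyTorch builds lack half-precision kernels for ops such as `conv2d`. A minimal sketch of both points, assuming the checkpoint is a dict that may wrap the weights under `"state_dict"` (as the script expects); the helper names `pick_device_and_dtype` and `load_checkpoint` are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


def pick_device_and_dtype():
    """Use fp16 only on CUDA; fall back to fp32 on CPU, where half-precision
    kernels are often missing (assumption: fp16 is only wanted for GPU export)."""
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.float16
    return torch.device("cpu"), torch.float32


def load_checkpoint(model: nn.Module, path: str, device: torch.device) -> nn.Module:
    """Load weights from `path` (e.g. args.resume) directly onto `device`.

    Accepts either a plain state dict or a dict wrapping one under "state_dict".
    """
    checkpoint = torch.load(path, map_location=device)
    state_dict = checkpoint.get("state_dict", checkpoint)
    model.load_state_dict(state_dict)
    return model.eval().to(device)
```

In the script this would read roughly `device, dtype = pick_device_and_dtype()`, then `siammask = load_checkpoint(Custom(anchors=cfg['anchors']), args.resume, device).to(dtype)`, with the dummy inputs created via `torch.randn(..., device=device, dtype=dtype)`.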

The output looks like this:

```
[2024-02-14 15:40:21,552-rk0-features.py# 66] Current training 0 layers:

[2024-02-14 15:40:21,554-rk0-features.py# 66] Current training 1 layers:

====== Diagnostic Run torch.onnx.export version 1.14.0a0+44dac51c.nv23.01 ======
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Traceback (most recent call last):
  File "../../tools/torch2onnx.py", line 78, in <module>
    torch.onnx.export(siammask, input_dict, "SiamMask_DAVIS_half_test.onnx",
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 1533, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 989, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 893, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/jit/_trace.py", line 1260, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1467, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'input'
```
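The last line says the model's `forward` has one required positional parameter named `input`, i.e. it expects the whole dict as a single argument, yet it was called with none. The `torch.onnx.export` documentation for this (TorchScript-based) exporter states that a dict at the end of `args` is interpreted as named arguments, and that to pass a dict as the last positional argument you append an empty dict to the `args` tuple. The sketch below checks the signature with `inspect` and builds `args` that way; `DictInputModel` is a hypothetical stand-in for a model with `forward(self, input)`, not the SiamMask class itself.

```python
import inspect

import torch
import torch.nn as nn


class DictInputModel(nn.Module):
    """Hypothetical stand-in for a model whose forward takes one dict named `input`."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, input):
        z = self.conv(input["template"])
        x = self.conv(input["search"])
        # Global-average the feature maps so each output is (batch, 8).
        return z.mean(dim=(2, 3)), x.mean(dim=(2, 3))


def required_positional_params(module: nn.Module):
    """Names of forward's positional parameters that have no default value."""
    sig = inspect.signature(module.forward)  # bound method, so `self` is excluded
    return [
        name
        for name, p in sig.parameters.items()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
    ]


def build_export_args(input_dict):
    """Pass `input_dict` as the single positional argument.

    The trailing empty dict means "no keyword arguments", so the exporter does
    not treat `input_dict` itself as keyword arguments.
    """
    return (input_dict, {})
```

With the script's objects this would be `torch.onnx.export(siammask, build_export_args(input_dict), "SiamMask_DAVIS_half_test.onnx", ...)`. This targets the `TypeError` above only; whether the rest of the trace succeeds, and how `input_names` map onto the flattened dict entries, still depends on the model.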

Specifically, I need help figuring out the exact inputs and outputs to pass to torch.onnx.export() for the conversion.
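One common way to pin down the exact inputs and outputs is to export a thin wrapper rather than the model itself: the wrapper takes `template` and `search` as plain positional tensors (so `input_names` line up one-to-one with its arguments) and returns a flat tuple of tensors (so each entry gets one name in `output_names`). The number of tensors the model actually returns in eval mode may differ from the three names in the script, so this sketch derives placeholder output names from a real forward pass. `ToyDictModel`, `ExportWrapper`, `output_key`, and `export_wrapper` are all hypothetical; check what the real model's `forward` returns before choosing `output_key`.

```python
from typing import List, Optional, Sequence

import torch
import torch.nn as nn


class ToyDictModel(nn.Module):
    """Hypothetical stand-in: forward(input) takes a dict and returns a dict of tensors."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 4, kernel_size=3, stride=2)

    def forward(self, input):
        z = self.backbone(input["template"])
        x = self.backbone(input["search"])
        return {"outputs": [x.mean(dim=1, keepdim=True), x, z]}


class ExportWrapper(nn.Module):
    """Expose a dict-input model as forward(template, search) -> tuple of tensors."""

    def __init__(self, model: nn.Module, output_key: Optional[str] = None):
        super().__init__()
        self.model = model
        self.output_key = output_key  # hypothetical key holding the output list, if any

    def forward(self, template: torch.Tensor, search: torch.Tensor):
        out = self.model({"template": template, "search": search})
        if self.output_key is not None:
            out = out[self.output_key]
        # Each tensor in this tuple becomes one ONNX output, in order.
        return tuple(out)


def output_names_for(outputs: Sequence[torch.Tensor]) -> List[str]:
    """Generic placeholder names; rename once you know what each tensor means."""
    return [f"output_{i}" for i in range(len(outputs))]


def export_wrapper(wrapper, template, search, path, opset_version=11):
    """Export with one named ONNX input per positional argument."""
    wrapper.eval()
    with torch.no_grad():
        outputs = wrapper(template, search)
    output_names = output_names_for(outputs)
    # Mark the batch dimension dynamic on both inputs and on every output.
    batch_axes = {name: {0: "batch_size"} for name in ["template", "search", *output_names]}
    torch.onnx.export(
        wrapper,
        (template, search),  # plain positional tensors, no dict
        path,
        input_names=["template", "search"],
        output_names=output_names,
        dynamic_axes=batch_axes,
        opset_version=opset_version,
        do_constant_folding=True,
    )
    return output_names
```

Applied to the script, the call would look like `export_wrapper(ExportWrapper(siammask, output_key=...), template, search, "SiamMask_DAVIS.onnx")`. Unlike the original `dynamic_axes`, this also marks `template` as batch-dynamic, since both inputs carry a batch dimension.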
