- Notifications
You must be signed in to change notification settings - Fork809
Description
Hey@foolwood,
i need help with conversion to ONNX format.
My python script using torch.onnx.export() for conversion looks like this:
from tools.test import *
#from siammask.models import Custom
from custom import Custom
parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')
parser.add_argument('--resume', default='', type=str, required=True,
metavar='PATH',help='path to latest checkpoint (default: none)')
parser.add_argument('--config', dest='config', default='config_davis.json',
help='hyper-parameter of SiamMask in json format')
#parser.add_argument('--base_path', default='../../data/tennis', help='datasets')
#parser.add_argument('--cpu', action='store_true', help='cpu mode')
args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True
cfg = load_config(args)
siammask = Custom(anchors=cfg['anchors'])
siammask.load_state_dict(torch.load('SiamMask_DAVIS.pth')["state_dict"])
siammask.eval().to(device)
siammask.half()
template = torch.randn(1, 3, 127, 127).to(device).half()
search = torch.randn(1, 3, 255, 255).to(device).half()
label_cls = torch.randn(1, 1, 5).to(device).half()
input_dict = {'template': template, 'search': search} #, 'label_cls': label_cls}
torch.onnx.export(siammask, input_dict, "SiamMask_DAVIS_half_test.onnx",
input_names=['template', 'search'],
opset_version=11,
do_constant_folding=True,
verbose=True,
output_names=['rpn_pred_cls', 'rpn_pred_loc', 'pred_mask'],
dynamic_axes={'search': {0: 'batch_size'}, # if you want batch size to be dynamic
'rpn_pred_cls': {0: 'batch_size'},
'rpn_pred_loc': {0: 'batch_size'},
'pred_mask': {0: 'batch_size'}})
The output looks like this:
[2024-02-14 15:40:21,552-rk0-features.py# 66] Current training 0 layers:
[2024-02-14 15:40:21,554-rk0-features.py# 66] Current training 1 layers:
====== Diagnostic Run torch.onnx.export version 1.14.0a0+44dac51c.nv23.01 ======
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
Traceback (most recent call last):
File "../../tools/torch2onnx.py", line 78, in
torch.onnx.export(siammask, input_dict, "SiamMask_DAVIS_half_test.onnx",
File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 506, in export
_export(
File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 1533, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 989, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 893, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/jit/_trace.py", line 1260, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
return forward_call(*args, **kwargs)
File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
outs.append(self.inner(*trace_inputs))
File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
return forward_call(*args, **kwargs)
File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1467, in _slow_forward
result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'input'
To be specific, i need help figuring out the exact set of input and output parameters for torch.onnx.export() to perform the conversion.