Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitff2f9ae

Browse files
committed
fix: add explicit cast for i64 outputs as they may not be supported in
all layersSigned-off-by: Naren Dasan <naren@narendasan.com>Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parentda25720 commitff2f9ae

File tree

3 files changed

+24
-10
lines changed

3 files changed

+24
-10
lines changed

‎py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py‎

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
fromtypingimportAny,Callable,Dict,List,NamedTuple,Optional,Sequence,Set
55

66
importnumpyasnp
7-
importtensorrtastrt
87
importtorch
98
importtorch.fx
109
fromtorch.fx.nodeimport_get_qualified_name
@@ -26,6 +25,7 @@
2625
fromtorch_tensorrt.fx.observerimportObserver
2726
fromtorch_tensorrt.loggingimportTRT_LOGGER
2827

28+
importtensorrtastrt
2929
frompackagingimportversion
3030

3131
_LOGGER:logging.Logger=logging.getLogger(__name__)
@@ -498,6 +498,9 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
498498
)
499499

500500
fori,outputinenumerate(outputs):
501+
name=f"output{i}"
502+
503+
output_dtype=dtype.unknown
501504
ifany(
502505
op_nameinoutput.name.split("_")
503506
forop_namein (
@@ -514,16 +517,20 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
514517
"any",
515518
)
516519
):
517-
output_bool=True
518-
else:
519-
output_bool=False
520-
name=f"output{i}"
521-
output.name=name
522-
self.ctx.net.mark_output(output)
523-
ifoutput_bool:
524-
output.dtype=trt.DataType.BOOL
520+
output_dtype=dtype.b
525521
elifself.output_dtypesisnotNone:
526-
output.dtype=self.output_dtypes[i].to(trt.DataType)
522+
ifself.output_dtypes[i]==dtype.i64:
523+
output=self.ctx.net.add_cast(
524+
output,dtype.i64.to(trt.DataType)
525+
).get_output(0)
526+
output_dtype=dtype.i64
527+
else:
528+
output_dtype=self.output_dtypes[i]
529+
530+
self.ctx.net.mark_output(output)
531+
ifoutput_dtypeisnotdtype.unknown:
532+
output.dtype=output_dtype.to(trt.DataType,use_default=True)
533+
output.name=name
527534

528535
self._output_names.append(name)
529536
_LOGGER.debug(

‎py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# mypy: disallow-untyped-decorators=False
2+
13
importlogging
24
importoperator
35
fromtypingimportAny,Callable,Dict,Optional,Sequence,Tuple,Union
@@ -858,6 +860,7 @@ def validate_dtype(to_copy_node: Node) -> bool:
858860
allowed_casts= {
859861
torch.float,
860862
torch.int32,
863+
torch.int64,
861864
torch.bool,
862865
torch.int8,
863866
torch.float16,

‎tests/py/dynamo/conversion/harness.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,10 @@ def run_test(
251251
truncate_double=compilation_settings.truncate_double,
252252
)
253253

254+
_LOGGER.debug(f"Compilation settings:{compilation_settings}")
255+
_LOGGER.debug(f"Inputs:{input_specs}")
256+
_LOGGER.debug(f"Output types:{output_dtypes}")
257+
254258
interp=TRTInterpreter(
255259
mod,
256260
input_specs,

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp