Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 5a779c6

Browse files
reenable back thor test (#3929)
1 parent ca0765c · commit 5a779c6

File tree

12 files changed

+99
-96
lines changed

12 files changed

+99
-96
lines changed

‎py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py‎

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
10 10   from torch.fx.node import Argument, Node, Target
11 11   from torch_tensorrt import ENABLED_FEATURES
12 12   from torch_tensorrt._features import needs_not_tensorrt_rtx
13    - from torch_tensorrt._utils import is_tensorrt_version_supported, is_thor
   13 + from torch_tensorrt._utils import is_tensorrt_version_supported
14 14   from torch_tensorrt.dynamo._settings import CompilationSettings
15 15   from torch_tensorrt.dynamo._SourceIR import SourceIR
16 16   from torch_tensorrt.dynamo.conversion import impl
@@ -429,7 +429,7 @@ def index_nonbool_validator(
429 429       node: Node, settings: Optional[CompilationSettings] = None
430 430   ) -> bool:
431 431       # for thor and tensorrt_rtx, we don't support boolean indices, due to nonzero op not supported
432     -     if is_thor() or ENABLED_FEATURES.tensorrt_rtx:
    432 +     if ENABLED_FEATURES.tensorrt_rtx:
433 433           index = node.args[1]
434 434           for ind in index:
435 435               if ind is not None:
@@ -3621,18 +3621,10 @@ def aten_ops_full(
3621 3621   )
3622 3622
3623 3623
3624      - def nonzero_validator(
3625      -     node: Node, settings: Optional[CompilationSettings] = None
3626      - ) -> bool:
3627      -     return not is_thor()
3628      -
3629      -
3630 3624   # currently nonzero is not supported for tensorrt_rtx
3631 3625   # TODO: lan to add the nonzero support once tensorrt_rtx team has added the support
3632      - # TODO: apbose to remove the capability validator once thor bug resolve in NGC
3633 3626   @dynamo_tensorrt_converter(
3634 3627       torch.ops.aten.nonzero.default,
3635      -     capability_validator=nonzero_validator,
3636 3628       supports_dynamic_shapes=True,
3637 3629       requires_output_allocator=True,
3638 3630   )

‎tests/py/dynamo/conversion/test_arange_aten.py‎

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,10 @@
55
importtorch_tensorrt
66
fromparameterizedimportparameterized
77
fromtorch.testing._internal.common_utilsimportrun_tests
8-
from torch_tensorrt._utils import is_tegra_platform, is_thor
98

109
from .harnessimportDispatchTestCase
1110

1211

13-
@unittest.skipIf(
14-
is_thor() or is_tegra_platform(),
15-
"Skipped on Thor and Tegra platforms",
16-
)
1712
classTestArangeConverter(DispatchTestCase):
1813
@parameterized.expand(
1914
[

‎tests/py/dynamo/conversion/test_cumsum_aten.py‎

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,10 @@
55
importtorch_tensorrt
66
fromparameterizedimportparameterized
77
fromtorch.testing._internal.common_utilsimportrun_tests
8-
from torch_tensorrt._utils import is_tegra_platform, is_thor
98

109
from .harnessimportDispatchTestCase
1110

1211

13-
@unittest.skipIf(
14-
is_thor() or is_tegra_platform(),
15-
"Skipped on Thor and Tegra platforms",
16-
)
1712
classTestCumsumConverter(DispatchTestCase):
1813
@parameterized.expand(
1914
[

‎tests/py/dynamo/conversion/test_index_aten.py‎

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
fromparameterizedimportparameterized
77
fromtorch.testing._internal.common_utilsimportrun_tests
88
fromtorch_tensorrtimportENABLED_FEATURES,Input
9-
from torch_tensorrt._utils import is_tegra_platform, is_thor
109

1110
from .harnessimportDispatchTestCase
1211

@@ -114,8 +113,8 @@ def forward(self, input):
114113
]
115114
)
116115
@unittest.skipIf(
117-
is_thor() or ENABLED_FEATURES.tensorrt_rtx,
118-
"Skipped on Thor or tensorrt_rtx due to nonzero not supported",
116+
ENABLED_FEATURES.tensorrt_rtx,
117+
"Skipped on tensorrt_rtx due to nonzero not supported",
119118
)
120119
deftest_index_constant_bool_mask(self,_,index,input):
121120
classTestModule(torch.nn.Module):
@@ -149,8 +148,8 @@ def forward(self, x, index0):
149148
)
150149

151150
@unittest.skipIf(
152-
is_thor() or ENABLED_FEATURES.tensorrt_rtx,
153-
"Skipped on Thor or tensorrt_rtx due to nonzero not supported",
151+
ENABLED_FEATURES.tensorrt_rtx,
152+
"Skipped on tensorrt_rtx due to nonzero not supported",
154153
)
155154
deftest_index_zero_two_dim_ITensor_mask(self):
156155
classTestModule(nn.Module):
@@ -163,10 +162,6 @@ def forward(self, x, index0):
163162
index0=torch.tensor([True,False])
164163
self.run_test(TestModule(), [input,index0],enable_passes=True)
165164

166-
@unittest.skipIf(
167-
is_thor(),
168-
"Skipped on Thor due to nonzero not supported",
169-
)
170165
deftest_index_zero_index_three_dim_ITensor(self):
171166
classTestModule(nn.Module):
172167
defforward(self,x,index0):
@@ -180,8 +175,8 @@ def forward(self, x, index0):
180175
self.run_test(TestModule(), [input,index0])
181176

182177
@unittest.skipIf(
183-
is_thor() or ENABLED_FEATURES.tensorrt_rtx,
184-
"Skipped on Thor or tensorrt_rtx due to nonzero not supported",
178+
ENABLED_FEATURES.tensorrt_rtx,
179+
"Skipped on tensorrt_rtx due to nonzero not supported",
185180
)
186181
deftest_index_zero_index_three_dim_mask_ITensor(self):
187182
classTestModule(nn.Module):
@@ -252,7 +247,7 @@ def forward(self, input):
252247

253248

254249
@unittest.skipIf(
255-
torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx or is_thor() or is_tegra_platform(),
250+
torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
256251
"nonzero is not supported for tensorrt_rtx",
257252
)
258253
classTestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase):

‎tests/py/dynamo/conversion/test_nonzero_aten.py‎

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66
fromparameterizedimportparameterized
77
fromtorch.testing._internal.common_utilsimportrun_tests
88
fromtorch_tensorrtimportInput
9-
from torch_tensorrt._utils import is_tegra_platform, is_thor
109

1110
from .harnessimportDispatchTestCase
1211

1312

1413
@unittest.skipIf(
15-
torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx or is_thor() or is_tegra_platform(),
14+
torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
1615
"nonzero is not supported for tensorrt_rtx",
1716
)
1817
classTestNonZeroConverter(DispatchTestCase):

‎tests/py/dynamo/conversion/test_sym_size.py‎

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,10 @@
44
importtorch.nnasnn
55
fromparameterizedimportparameterized
66
fromtorch.testing._internal.common_utilsimportrun_tests
7-
from torch_tensorrt._utils import is_thor
87

98
from .harnessimportDispatchTestCase
109

1110

12-
@unittest.skipIf(
13-
is_thor(),
14-
"Skipped on Thor",
15-
)
1611
classTestSymSizeConverter(DispatchTestCase):
1712
@parameterized.expand(
1813
[

‎tests/py/dynamo/models/test_export_kwargs_serde.py‎

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# type: ignore
22
importos
3-
importtempfile
43
importunittest
54

65
importpytest
@@ -22,7 +21,7 @@
2221

2322
@pytest.mark.unit
2423
@pytest.mark.critical
25-
deftest_custom_model():
24+
deftest_custom_model(tmpdir):
2625
classnet(nn.Module):
2726
def__init__(self):
2827
super().__init__()
@@ -75,15 +74,15 @@ def forward(self, x, b=5, c=None, d=None):
7574
)
7675

7776
# Save the module
78-
trt_ep_path=os.path.join(tempfile.gettempdir(),"compiled.ep")
77+
trt_ep_path=os.path.join(tmpdir,"compiled.ep")
7978
torchtrt.save(trt_gm,trt_ep_path,retrace=False)
8079
# Clean up model env
8180
torch._dynamo.reset()
8281

8382

8483
@pytest.mark.unit
8584
@pytest.mark.critical
86-
deftest_custom_model_with_dynamo_trace():
85+
deftest_custom_model_with_dynamo_trace(tmpdir):
8786
classnet(nn.Module):
8887
def__init__(self):
8988
super().__init__()
@@ -137,15 +136,15 @@ def forward(self, x, b=5, c=None, d=None):
137136
)
138137

139138
# Save the module
140-
trt_ep_path=os.path.join(tempfile.gettempdir(),"compiled.ep")
139+
trt_ep_path=os.path.join(tmpdir,"compiled.ep")
141140
torchtrt.save(trt_gm,trt_ep_path,retrace=False)
142141
# Clean up model env
143142
torch._dynamo.reset()
144143

145144

146145
@pytest.mark.unit
147146
@pytest.mark.critical
148-
deftest_custom_model_with_dynamo_trace_dynamic():
147+
deftest_custom_model_with_dynamo_trace_dynamic(tmpdir):
149148
classnet(nn.Module):
150149
def__init__(self):
151150
super().__init__()
@@ -208,15 +207,15 @@ def forward(self, x, b=5, c=None, d=None):
208207
)
209208

210209
# Save the module
211-
trt_ep_path=os.path.join(tempfile.gettempdir(),"compiled.ep")
210+
trt_ep_path=os.path.join(tmpdir,"compiled.ep")
212211
torchtrt.save(trt_gm,trt_ep_path,retrace=False)
213212
# Clean up model env
214213
torch._dynamo.reset()
215214

216215

217216
@pytest.mark.unit
218217
@pytest.mark.critical
219-
deftest_custom_model_with_dynamo_trace_kwarg_dynamic():
218+
deftest_custom_model_with_dynamo_trace_kwarg_dynamic(tmpdir):
220219
ir="dynamo"
221220

222221
classnet(nn.Module):
@@ -298,15 +297,15 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
298297
msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score:{cos_sim} Threshold:{COSINE_THRESHOLD}",
299298
)
300299
# Save the module
301-
trt_ep_path=os.path.join(tempfile.gettempdir(),"compiled.ep")
300+
trt_ep_path=os.path.join(tmpdir,"compiled.ep")
302301
torchtrt.save(trt_gm,trt_ep_path,retrace=False)
303302
# Clean up model env
304303
torch._dynamo.reset()
305304

306305

307306
@pytest.mark.unit
308307
@pytest.mark.critical
309-
deftest_custom_model_with_dynamo_trace_kwarg_dynamic():
308+
deftest_custom_model_with_dynamo_trace_kwarg_dynamic(tmpdir):
310309
ir="dynamo"
311310

312311
classnet(nn.Module):
@@ -388,7 +387,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
388387
msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score:{cos_sim} Threshold:{COSINE_THRESHOLD}",
389388
)
390389
# Save the module
391-
trt_ep_path=os.path.join(tempfile.gettempdir(),"compiled.ep")
390+
trt_ep_path=os.path.join(tmpdir,"compiled.ep")
392391
torchtrt.save(trt_gm,trt_ep_path,retrace=False)
393392
# Clean up model env
394393
torch._dynamo.reset()

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp