@@ -6724,12 +6724,7 @@ def fn(values, same_size):
67246724check_results (fn ,compiled_fn ,generate_inp (20 ))
67256725self .assertEqual (compile_counter .frame_count ,frame_count_2 )
67266726
6727- # Note 1: Math fallback doesn't work with bfloat16 on CUDA
6728- # Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT
6729- @unittest .skipIf (
6730- TEST_WITH_ROCM ,
6731- "ROCm doesn't support flash attention or mem_efficient attention for NT" ,
6732- )
6727+ # Note: Math fallback doesn't work with bfloat16 on CUDA
67336728@tf32_on_and_off (0.005 )
67346729@dtypes (
67356730* (
@@ -6999,9 +6994,7 @@ def check_forward_backward(skip_backward=False):
69996994
70006995@skipIfTorchDynamo ("SDPA test compiles internally" )
70016996@skipCUDAIf (not SM70OrLater ,"GPU capability is < SM70" )
7002- # Guarding with sqrt() doesn't work on ROCm?
70036997@xfailIfWindows
7004- @skipCUDAIfRocm
70056998@onlyCUDA
70066999@dtypes (
70077000* (
@@ -7188,8 +7181,6 @@ def in_proj(input_packed, qkv_linear=qkv_linear):
71887181@decorateIf (xfailIfWindows ,lambda params :params ["dtype" ]== torch .float32 )
71897182@skipIfTorchDynamo ("SDPA test compiles internally" )
71907183@skipCUDAIf (not SM70OrLater ,"GPU capability is < SM70" )
7191- # mha_varlen_fwd not supported on ROCm
7192- @skipCUDAIfRocm
71937184@onlyCUDA
71947185@dtypes (
71957186* (
@@ -7220,7 +7211,6 @@ def f(values, offsets):
72207211"Platform doesn't support flash or mem-efficient attention" ,
72217212 )
72227213@skipCUDAIf (not SM70OrLater ,"GPU capability is < SM70" )
7223- @skipCUDAIfRocm
72247214@onlyCUDA
72257215@skipIfTorchDynamo ()
72267216def test_sdpa_autocast (self ,device ):
@@ -7303,7 +7293,6 @@ def get_values():
73037293"Platform doesn't support flash or mem-efficient attention" ,
73047294 )
73057295@skipCUDAIf (not SM70OrLater ,"GPU capability is < SM70" )
7306- @skipCUDAIfRocm
73077296@onlyCUDA
73087297@skipIfTorchDynamo ()
73097298def test_sdpa_flop_counter (self ,device ):
@@ -7726,7 +7715,6 @@ def test_jagged_padded_dense_conversion_kernels(self, device, dtype):
77267715@dtypes (torch .float32 )
77277716@skipIfTorchDynamo ("Test compiles internally" )
77287717@skipCUDAIf (not SM70OrLater ,"GPU capability is < SM70" )
7729- @skipCUDAIfRocm
77307718def test_compile_preserves_metadata_cache (self ,device ,dtype ):
77317719# shape (B, *, D)
77327720nt = random_nt_from_dims (
@@ -7753,7 +7741,6 @@ def f(nt):
77537741@dtypes (torch .float32 )
77547742@skipIfTorchDynamo ("Test compiles internally" )
77557743@skipCUDAIf (not SM70OrLater ,"GPU capability is < SM70" )
7756- @skipCUDAIfRocm
77577744def test_compile_with_dynamic_max_seq_len (self ,device ,dtype ):
77587745# shape (B, *, D)
77597746# max seq len: 18
@@ -7786,7 +7773,6 @@ def f(nt):
77867773@dtypes (torch .float32 )
77877774@skipIfTorchDynamo ("Test compiles internally" )
77887775@skipCUDAIf (not SM70OrLater ,"GPU capability is < SM70" )
7789- @skipCUDAIfRocm
77907776def test_compile_with_dynamic_min_seq_len (self ,device ,dtype ):
77917777# shape (B, *, D)
77927778# min seq len: 7
@@ -7819,7 +7805,6 @@ def f(nt):
78197805@dtypes (torch .float32 )
78207806@skipIfTorchDynamo ("Test compiles internally" )
78217807@skipCUDAIf (not SM70OrLater ,"GPU capability is < SM70" )
7822- @skipCUDAIfRocm
78237808def test_compile_with_propagated_dynamic_max_seq_len (self ,device ,dtype ):
78247809# shape (B, *, D)
78257810# max seq len: 18
@@ -7946,7 +7931,6 @@ def test_to_padded_tensor(self, device, dtype, nt_dim, requires_grad):
79467931@torch ._dynamo .utils .disable_cache_limit ()
79477932@skipIfTorchDynamo ("SDPA test compiles internally" )
79487933@skipCUDAIf (not SM70OrLater ,"GPU capability is < SM70" )
7949- @skipCUDAIfRocm
79507934@dtypes (torch .float32 ,torch .double ,torch .half )
79517935@parametrize ("nt_dim" , [2 ,3 ,4 ])
79527936@parametrize ("requires_grad" , [False ,True ])
@@ -8048,7 +8032,6 @@ def _g(nt):
80488032@dtypes (torch .float32 )
80498033@skipIfTorchDynamo ("Test compiles internally" )
80508034@skipCUDAIf (not SM70OrLater ,"GPU capability is < SM70" )
8051- @skipCUDAIfRocm
80528035def test_compile_padded_dense_conversion_preserves_metadata_cache (
80538036self ,device ,dtype
80548037 ):