@@ -96,7 +96,7 @@ def temp_float32_matmul_precision(precision: str):
9696
def skip_on_cpu(test_func):
    """Decorator to skip tests that are not supported on CPU.

    Args:
        test_func: The test function/method to wrap.

    Returns:
        The test wrapped with an unconditional CPU skip.
    """
    # skipCPUIf(True, ...) always skips when running on CPU; the reason
    # string previously said "CUDA" by mistake — it must name CPU, since
    # that is the device being skipped.
    decorated_func = skipCPUIf(True, "Not supported on CPU")(test_func)
    return decorated_func
101101
102102
@@ -2851,6 +2851,7 @@ def test_strided_backwards(self):
28512851 (1 ,0 ,2 ,3 ),# Reverse order
28522852 (0 ,2 ,1 ,3 ),# Mixed order
28532853 (2 ,0 ,1 ,3 ),# Another mixed order
2854+ (0 ,1 ,3 ,2 ),# Non contiguous last dim
28542855 ],
28552856 )
28562857@common_utils .parametrize ("shape" , [(2 ,1 ,128 ,16 ), (4 ,2 ,64 ,16 )])
@@ -2899,12 +2900,7 @@ def test_flex_attention_stride_ordering(self, device, mode, permute_order, shape
28992900@common_utils .parametrize ("mode" , ["eager" ,"inductor" ])
29002901@common_utils .parametrize (
29012902"permute_order" ,
2902- [
2903- (0 ,1 ,2 ,3 ),
2904- (1 ,0 ,2 ,3 ),
2905- (0 ,2 ,1 ,3 ),
2906- (2 ,0 ,1 ,3 ),
2907- ],
2903+ [(0 ,1 ,2 ,3 ), (1 ,0 ,2 ,3 ), (0 ,2 ,1 ,3 ), (2 ,0 ,1 ,3 ), (0 ,1 ,3 ,2 )],
29082904 )
29092905@common_utils .parametrize ("shape" , [(2 ,5 ,128 ,16 ), (4 ,2 ,64 ,16 )])
29102906def test_flex_attention_backward_stride_ordering (
@@ -2948,6 +2944,69 @@ def test_flex_attention_backward_stride_ordering(
29482944f"Mode:{ mode } , Stride order mismatch for{ name } : grad{ input_stride_order } , input{ orig_stride_order } ." ,
29492945 )
29502946
    @supported_platform
    def test_non_contiguous_last_dim(self, device):
        """Test flex_attention with tensors having non contiguous last dimension.

        Builds q/k/v that are column-major in their last two dims (so
        stride(-1) != 1), then checks that compiled flex_attention matches
        an eager reference and a float64 "golden" run — including gradients
        on devices that support backward.
        """
        B, H, D = 4, 8, 64
        # float16 on CUDA exercises the low-precision kernel path; CPU
        # reference path stays in float32.
        dtype = torch.float16 if device == "cuda" else torch.float32
        for S in [16, 64]:

            def column_major_tensor():
                # Allocate contiguously, then transpose/contiguous/transpose
                # so the result has the same logical shape (B, H, S, D) but
                # is column major in the last 2 dims (last-dim stride != 1).
                tensor = torch.randn(
                    (B, H, S, D),
                    dtype=dtype,
                    device=device,
                )
                # Column major in last 2 dims
                return tensor.transpose(-1, -2).contiguous().transpose(-1, -2)

            q = column_major_tensor()
            k = column_major_tensor()
            v = column_major_tensor()

            # Only request gradients where backward is supported for this
            # device (DEVICE_SUPPORTS_BACKWARDS is defined elsewhere in
            # this file).
            requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
            if requires_grad:
                q.requires_grad_(True)
                k.requires_grad_(True)
                v.requires_grad_(True)

            # Sanity-check the setup: the last dim really is non-contiguous.
            self.assertNotEqual(q.stride()[-1], 1)
            self.assertNotEqual(k.stride()[-1], 1)
            self.assertNotEqual(v.stride()[-1], 1)

            # Independent clones: _ref for the eager baseline, _gold in
            # float64 for a high-precision reference.
            q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
            q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)

            golden_out = flex_attention(q_gold, k_gold, v_gold)
            ref_out = flex_attention(q_ref, k_ref, v_ref)

            # dynamic=True so both S values in this loop compile to one
            # dynamic-shape graph rather than recompiling per size.
            flex_compiled = torch.compile(flex_attention, fullgraph=True, dynamic=True)
            compiled_out = flex_compiled(q, k, v)

            self._check_out(golden_out, ref_out, compiled_out)

            if requires_grad:
                backward_grad = torch.randn_like(ref_out)

                # Golden run needs the grad in its own (float64) dtype.
                golden_out.backward(backward_grad.to(torch.float64))
                ref_out.backward(backward_grad)
                compiled_out.backward(backward_grad)

                # Compare outputs and all q/k/v gradients across golden,
                # eager-reference, and compiled runs.
                self._check_out_and_grad(
                    golden_out,
                    ref_out,
                    compiled_out,
                    q_gold,
                    q_ref,
                    q,
                    k_gold,
                    k_ref,
                    k,
                    v_gold,
                    v_ref,
                    v,
                )
29513010@supported_platform
29523011@common_utils .parametrize ("compile" , [True ,False ])
29533012def test_fully_masked_out_rows_0_check (self ,device ,compile :bool ):