Commit cd6274e

[FlexAttention] Remove Old Constraint on last dim strides

ghstack-source-id: 2dd8f81
Pull Request resolved: #151959

1 parent ca17c81, commit cd6274e

2 files changed: 80 additions & 12 deletions
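Before this change, the Inductor lowering asserted that the query's stride in the last (head) dimension was 1, so inputs laid out column major in their last two dimensions could not be compiled. Below is a minimal sketch of the call pattern this commit enables, assuming a CUDA device and a PyTorch build that exposes torch.nn.attention.flex_attention; the helper column_major_last_two_dims is illustrative only, not part of the codebase.

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 128, 64  # illustrative shapes

def column_major_last_two_dims(shape, device="cuda", dtype=torch.float16):
    # Materialize the data so the last two dims are column major,
    # i.e. stride()[-1] != 1 for the head dimension.
    t = torch.randn(shape, device=device, dtype=dtype)
    return t.transpose(-1, -2).contiguous().transpose(-1, -2)

q = column_major_last_two_dims((B, H, S, D))
k = column_major_last_two_dims((B, H, S, D))
v = column_major_last_two_dims((B, H, S, D))
assert q.stride()[-1] != 1  # last dim is not contiguous

# Previously the compiled lowering rejected this with
# "Query must be contiguous in the last dimension"; with this
# commit the compiled call is expected to work.
out = torch.compile(flex_attention, fullgraph=True)(q, k, v)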
test/inductor/test_flex_attention.py
Lines changed: 66 additions & 7 deletions

@@ -96,7 +96,7 @@ def temp_float32_matmul_precision(precision: str):
 
 def skip_on_cpu(test_func):
     """Decorator to skip tests that are not supported on CPU."""
-    decorated_func = skipCPUIf(True, "Not supported on CUDA")(test_func)
+    decorated_func = skipCPUIf(True, "Not supported on CPU")(test_func)
     return decorated_func
 
 
@@ -2851,6 +2851,7 @@ def test_strided_backwards(self):
            (1, 0, 2, 3),  # Reverse order
            (0, 2, 1, 3),  # Mixed order
            (2, 0, 1, 3),  # Another mixed order
+            (0, 1, 3, 2),  # Non contiguous last dim
        ],
    )
    @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)])
@@ -2899,12 +2900,7 @@ def test_flex_attention_stride_ordering(self, device, mode, permute_order, shape
    @common_utils.parametrize("mode", ["eager", "inductor"])
    @common_utils.parametrize(
        "permute_order",
-        [
-            (0, 1, 2, 3),
-            (1, 0, 2, 3),
-            (0, 2, 1, 3),
-            (2, 0, 1, 3),
-        ],
+        [(0, 1, 2, 3), (1, 0, 2, 3), (0, 2, 1, 3), (2, 0, 1, 3), (0, 1, 3, 2)],
    )
    @common_utils.parametrize("shape", [(2, 5, 128, 16), (4, 2, 64, 16)])
    def test_flex_attention_backward_stride_ordering(
@@ -2948,6 +2944,69 @@ def test_flex_attention_backward_stride_ordering(
                f"Mode: {mode}, Stride order mismatch for {name}: grad {input_stride_order}, input {orig_stride_order}.",
            )
 
+    @supported_platform
+    def test_non_contiguous_last_dim(self, device):
+        """Test flex_attention with tensors having non contiguous last dimension."""
+        B, H, D = 4, 8, 64
+        dtype = torch.float16 if device == "cuda" else torch.float32
+        for S in [16, 64]:
+
+            def column_major_tensor():
+                tensor = torch.randn(
+                    (B, H, S, D),
+                    dtype=dtype,
+                    device=device,
+                )
+                # Column major in last 2 dims
+                return tensor.transpose(-1, -2).contiguous().transpose(-1, -2)
+
+            q = column_major_tensor()
+            k = column_major_tensor()
+            v = column_major_tensor()
+
+            requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
+            if requires_grad:
+                q.requires_grad_(True)
+                k.requires_grad_(True)
+                v.requires_grad_(True)
+
+            self.assertNotEqual(q.stride()[-1], 1)
+            self.assertNotEqual(k.stride()[-1], 1)
+            self.assertNotEqual(v.stride()[-1], 1)
+
+            q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
+            q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
+
+            golden_out = flex_attention(q_gold, k_gold, v_gold)
+            ref_out = flex_attention(q_ref, k_ref, v_ref)
+
+            flex_compiled = torch.compile(flex_attention, fullgraph=True, dynamic=True)
+            compiled_out = flex_compiled(q, k, v)
+
+            self._check_out(golden_out, ref_out, compiled_out)
+
+            if requires_grad:
+                backward_grad = torch.randn_like(ref_out)
+
+                golden_out.backward(backward_grad.to(torch.float64))
+                ref_out.backward(backward_grad)
+                compiled_out.backward(backward_grad)
+
+                self._check_out_and_grad(
+                    golden_out,
+                    ref_out,
+                    compiled_out,
+                    q_gold,
+                    q_ref,
+                    q,
+                    k_gold,
+                    k_ref,
+                    k,
+                    v_gold,
+                    v_ref,
+                    v,
+                )
+
    @supported_platform
    @common_utils.parametrize("compile", [True, False])
    def test_fully_masked_out_rows_0_check(self, device, compile: bool):
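For reference, the transpose → contiguous → transpose pattern used in the new test keeps the values identical but stores the last two dimensions column major, so the head-dimension stride becomes S instead of 1. A standalone sketch with arbitrary shapes:

import torch

t = torch.randn(4, 8, 16, 64)  # (B, H, S, D), row major
cm = t.transpose(-1, -2).contiguous().transpose(-1, -2)

print(t.stride())          # (8192, 1024, 64, 1): last dim contiguous
print(cm.stride())         # (8192, 1024, 1, 16): last dim stride is S, not 1
print(torch.equal(t, cm))  # True: same values, different memory layout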
torch/_inductor/kernel/flex_attention.py
Lines changed: 14 additions & 5 deletions

@@ -930,6 +930,15 @@ def check_cpu_supported():
    return supported
 
 
+def contiguous_last_dim(x):
+    """Ensure that realized IR node has a contiguous stride in the last dimension."""
+    strides = x.maybe_get_stride()
+    if strides and strides[-1] != 1:
+        contiguous_stride_order = list(reversed(range(len(x.get_size()))))
+        return ExternKernel.require_stride_order(x, contiguous_stride_order)
+    return x
+
+
 def lower_cpu(
    query,
    key,
@@ -1092,6 +1101,9 @@ def convert_mask_graph_module(mask_graph):
        if isinstance(item, TensorBox):
            fake_buffers.append(item.data.data)  # type: ignore[attr-defined]
 
+    # CPU kernel requires last dim to be contiguous
+    query, key, value = map(contiguous_last_dim, [query, key, value])
+
    (
        query,
        key,
@@ -1258,7 +1270,6 @@ def set_head_dim_values(
    )
 
 
-# TODO: We probably also need a layout constraint?
 @register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
 def flex_attention(
    query,
@@ -1413,11 +1424,9 @@ def flex_attention(
    else:
        kernel_options.setdefault("IS_DIVISIBLE", True)
 
-    # Reuse query strides for output layout despite different last dimension.
-    # This works because only the last dim differs and we check it is contiguous.
+    # NB it is okay that the v_head_dim is different
+    # We are using these to match fill order of the output.
    q_strides = query.get_stride()
-    assert q_strides[-1] == 1, "Query must be contiguous in the last dimension"
-
    # Construct output layout with strides matching the query.
    out_size = [B, Hq, seq_len_q, v_head_dim]
    out_strides = infer_dense_strides(out_size, q_strides)

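The new contiguous_last_dim helper only forces a plain contiguous realization when the last-dim stride is not 1, while the output layout now follows the query's fill order via infer_dense_strides even though v_head_dim may differ from the query head dim. A rough pure-Python stand-in for what "dense strides matching a reference fill order" means (illustrative only, not Inductor's implementation):

def dense_strides_like(size, ref_strides):
    # Order dims from fastest varying (smallest reference stride) to slowest,
    # then assign dense, gap-free strides in that order.
    fill_order = sorted(range(len(size)), key=lambda d: ref_strides[d])
    strides = [0] * len(size)
    running = 1
    for d in fill_order:
        strides[d] = running
        running *= size[d]
    return strides

# Query stored column major in its last two dims, e.g. shape (2, 4, 128, 64):
q_strides = [32768, 8192, 1, 128]   # last-dim stride != 1
out_size = [2, 4, 128, 32]          # v_head_dim differs from the query head dim
print(dense_strides_like(out_size, q_strides))  # [16384, 4096, 1, 128]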