Commit 6a974be

ezyang authored and pytorchmergebot committed
Change flash attention outputs to be SymInt instead of int (#110533)
Fixes #110322

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #110533
Approved by: https://github.com/albanD
1 parent f1d8113 · commit 6a974be
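
For context, a minimal sketch of the scenario this change targets (adapted from the test_flash_attention_dynamic test added below; it assumes a CUDA build with flash attention support and is not part of the commit). With the previous int-typed max_q/max_k outputs, compiling scaled_dot_product_attention with dynamic shapes could specialize on the sequence length and recompile per length (#110322); with SymInt outputs, a single dynamic graph can serve all lengths.

import torch
import torch.nn.functional as F

def sdpa(q, k, v):
    # Flash attention is selected via the sdp_kernel context below.
    return F.scaled_dot_product_attention(q, k, v)

compiled = torch.compile(sdpa, dynamic=True)

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    for seq_len in (512, 513, 514):
        # (batch, heads, seq_len, head_dim), fp16 on CUDA as flash attention requires.
        q = torch.rand(5, 8, seq_len, 128, device="cuda", dtype=torch.float16)
        k = torch.rand(5, 8, seq_len, 128, device="cuda", dtype=torch.float16)
        v = torch.rand(5, 8, seq_len, 128, device="cuda", dtype=torch.float16)
        compiled(q, k, v)  # expected: one compilation reused across all three lengths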

File tree

10 files changed (+65, -15 lines)


aten/src/ATen/native/native_functions.yaml

Lines changed: 4 additions & 4 deletions

@@ -14349,14 +14349,14 @@
   variants: function
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   dispatch:
     CPU: _scaled_dot_product_flash_attention_cpu
     CUDA: _scaled_dot_product_flash_attention_cuda
     NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
   device_check: NoCheck
   variants: function
   dispatch:
@@ -14375,13 +14375,13 @@
     CUDA: _scaled_dot_product_efficient_attention_backward_cuda
   tags: nondeterministic_seeded
 
-- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, int? max_q, int? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt? max_q, SymInt? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   variants: function
   dispatch:
     CUDA: _flash_attention_forward
   tags: nondeterministic_seeded
 
-- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
   device_check: NoCheck
   variants: function
   dispatch:
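
As a quick sanity check (a suggested workflow, not part of the commit), the updated return types can be read back from the registered operator schema:

import torch

# Prints the FunctionSchema for the flash attention forward op; after this
# change the max_q/max_k returns are declared as SymInt rather than int.
print(torch.ops.aten._scaled_dot_product_flash_attention.default._schema)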

aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp

Lines changed: 2 additions & 2 deletions

@@ -220,8 +220,8 @@ std::tuple<
     Tensor,
     Tensor,
     Tensor,
-    int64_t,
-    int64_t,
+    c10::SymInt,
+    c10::SymInt,
     Tensor,
     Tensor,
     Tensor>

aten/src/ATen/native/transformers/attention.cpp

Lines changed: 2 additions & 2 deletions

@@ -744,8 +744,8 @@ std::tuple<
     at::Tensor,
     at::Tensor,
     at::Tensor,
-    int64_t,
-    int64_t,
+    c10::SymInt,
+    c10::SymInt,
     at::Tensor,
     at::Tensor,
     at::Tensor>

aten/src/ATen/native/transformers/cuda/attention.cu

Lines changed: 1 addition & 1 deletion

@@ -668,7 +668,7 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
   }
   return std::make_tuple(std::move(proj), std::move(qkt));
 }
-std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
+std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
     const Tensor& query,
     const Tensor& key,
     const Tensor& value,

test/inductor/test_cuda_repro.py

Lines changed: 48 additions & 0 deletions

@@ -5,13 +5,16 @@
 
 import torch
 import torch._dynamo.config as dynamo_config
+import torch.backends.cuda
+import torch.nn.functional as F
 from torch import nn
 from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.testing import rand_strided
 from torch._dynamo.utils import same
 from torch._inductor import config
 from torch._inductor.compile_fx import compile_fx_inner
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_utils import (
     DeterministicGuard,
     freeze_rng_state,
@@ -982,6 +985,51 @@ def fn(x, y, z):
 
         self.assertEqual(ref, res)
 
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention not supported"
+    )
+    def test_flash_attention_dynamic(self):
+        class Model(nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+                self.q = nn.Linear(1024, 1024)
+                self.k = nn.Linear(1024, 1024)
+                self.v = nn.Linear(1024, 1024)
+
+            def forward(self, x):
+                batch_size, seq_len, _ = x.size()
+
+                queries = self.q(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
+                keys = self.k(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
+                values = self.v(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
+
+                attn = F.scaled_dot_product_attention(
+                    queries,
+                    keys,
+                    values,
+                )
+
+                return attn
+
+        cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")
+
+        model = Model().cuda().half()
+        model = torch.compile(model, backend=cnts, dynamic=True)
+
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=True, enable_math=False, enable_mem_efficient=False
+        ):
+            input1 = torch.rand(5, 512, 1024, device="cuda", dtype=torch.float16)
+            input2 = torch.rand(5, 513, 1024, device="cuda", dtype=torch.float16)
+            input3 = torch.rand(5, 514, 1024, device="cuda", dtype=torch.float16)
+
+            out1 = model(input1)
+            out2 = model(input2)
+            out3 = model(input3)
+
+        self.assertEqual(cnts.frame_count, 1)
+
     @config.patch({"triton.cudagraphs": True})
     def test_index_put_no_fallback_cudagraph(self):
         def fn(x, y, z):

tools/autograd/derivatives.yaml

Lines changed: 2 additions & 2 deletions

@@ -2764,9 +2764,9 @@
   output_differentiability: [True, False, False, False]
   query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale)
 
-- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   output_differentiability: [True, False, False, False, False, False, False, False, False]
-  query, key, value: _scaled_dot_product_flash_attention_backward(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
+  query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
 
 # - name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, int? max_q, int? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor query_padded, Tensor key_padded, Tensor value_padded, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
 #   output_differentiability: [True, False, False, False, False, False, False, False]

torch/_C/return_types.pyi.in

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ from typing import (
     Union,
 )
 
-from torch import contiguous_format, Generator, inf, memory_format, strided, Tensor
+from torch import contiguous_format, Generator, inf, memory_format, strided, Tensor, SymInt
 from torch.types import (
     _bool,
     _device,

torch/_inductor/ir.py

Lines changed: 2 additions & 0 deletions

@@ -3984,6 +3984,8 @@ def generate_output(output, indices):
             )
         elif isinstance(output, int):
            return output
+        elif isinstance(output, torch.SymInt):
+            return output.node.expr
         else:
             assert (
                 output is None
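
The new branch relies on torch.SymInt exposing its underlying sympy expression as output.node.expr. Below is a standalone sketch of that dispatch logic for illustration only; _ir_output_value is a hypothetical name, not a function in torch/_inductor/ir.py.

import torch

def _ir_output_value(output):
    # Mirrors the generate_output handling above: plain Python ints pass
    # through unchanged, while symbolic ints are lowered to their sympy
    # expression so Inductor can keep reasoning about them symbolically.
    if isinstance(output, int):
        return output
    elif isinstance(output, torch.SymInt):
        return output.node.expr
    else:
        assert output is None
        return None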

torch/csrc/inductor/aoti_torch/shim_common.cpp

Lines changed: 2 additions & 2 deletions

@@ -228,8 +228,8 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
     at::Tensor* ret3_tensor = new at::Tensor(std::move(r3));
     *ret3 = tensor_pointer_to_tensor_handle(ret3_tensor);
   }
-  *ret4 = r4;
-  *ret5 = r5;
+  *ret4 = r4.expect_int();
+  *ret5 = r5.expect_int();
   at::Tensor* ret6_tensor = new at::Tensor(std::move(r6));
   *ret6 = tensor_pointer_to_tensor_handle(ret6_tensor);
   at::Tensor* ret7_tensor = new at::Tensor(std::move(r7));

torchgen/api/python.py

Lines changed: 1 addition & 1 deletion

@@ -1129,7 +1129,7 @@ def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
     "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
     "::std::vector<at::Tensor>",
     # Needed for flash attention forw/backward
-    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t,int64_t,at::Tensor,at::Tensor,at::Tensor>",
+    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
     "at::Scalar",
     "bool",
     "int64_t",
