Commit e8de914

Valentine233 authored and pytorchmergebot committed
[CPU][Flex attn] Add a readable error message for the backward path (#169646)
Fixes #169224. Flex attention does not support the backward path on CPU. This PR adds a readable and meaningful error message for that case.

Before:

```
Traceback (most recent call last):
  File "/workspace/test_flex_attn.py", line 24, in <module>
    output = flex_attention(query, key, value, block_mask=block_mask)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 940, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1019, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1003, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1757, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1452, in codegen_and_compile
    graph.run(*example_inputs)
  File "/workspace/pytorch/torch/_inductor/graph.py", line 987, in run
    return super().run(*args)
  File "/workspace/pytorch/torch/fx/interpreter.py", line 200, in run
    self.env[node] = self.run_node(node)
  File "/workspace/pytorch/torch/_inductor/graph.py", line 1726, in run_node
    result = super().run_node(n)
  File "/workspace/pytorch/torch/fx/interpreter.py", line 295, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/workspace/pytorch/torch/_inductor/graph.py", line 1257, in call_function
    return super().call_function(target, args, kwargs)
  File "/workspace/pytorch/torch/fx/interpreter.py", line 375, in call_function
    return target(*args, **kwargs)
torch._inductor.exc.InductorError: IndexError: tuple index out of range

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```

After:

```
Traceback (most recent call last):
  File "/workspace/test_flex_attn.py", line 24, in <module>
    output = flex_attention(query, key, value, block_mask=block_mask)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 926, in compile_wrapper
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/nn/attention/flex_attention.py", line 1481, in flex_attention
    _validate_device(query, key, value)
  File "/workspace/pytorch/torch/nn/attention/flex_attention.py", line 1332, in _validate_device
    raise NotImplementedError(
NotImplementedError: FlexAttention does not support backward on CPU. Please set the input requires_grad to False or use another device.
```

Pull Request resolved: #169646
Approved by: https://github.com/mingfeima, https://github.com/mlazos
1 parent 13c036a commit e8de914
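
For context, a minimal sketch of the new behavior described in the commit message (not part of the PR; the shapes are illustrative and it assumes a PyTorch build that includes this change):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# CPU inputs that require grad now fail fast with a readable message
# instead of an InductorError deep inside compilation.
q = torch.randn(1, 1, 128, 64, requires_grad=True)
k = torch.randn(1, 1, 128, 64, requires_grad=True)
v = torch.randn(1, 1, 128, 64, requires_grad=True)

try:
    flex_attention(q, k, v)
except NotImplementedError as err:
    print(err)  # FlexAttention does not support backward on CPU. ...

# Workaround suggested by the message: drop requires_grad for the
# forward-only CPU path, or move the tensors to a supported device.
out = flex_attention(q.detach(), k.detach(), v.detach())
```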

File tree

7 files changed: +36 -12 lines

test/distributed/pipelining/test_microbatch.py

Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
+    skipCPUIf,
     skipXPUIf,
 )
 from torch.testing._internal.common_utils import run_tests, TestCase
@@ -60,6 +61,7 @@ def test_split_and_merge(self):
         torch.testing.assert_close(merged_kwargs, kwargs)
         print("Microbatch test passed")

+    @skipCPUIf(True, "Flex attention backward is not supported on CPU")
     def test_split_block_mask(self, device):
         B = 6
         H = 1

test/dynamo/test_repros.py

Lines changed: 3 additions & 3 deletions

@@ -7427,9 +7427,9 @@ def dynamic_mask_function(batch_idx, head_idx, q_idx, kv_idx):
             device=x.device,
             _compile=False,
         )
-        q = processed.view(batch_size, 1, seq_len, self.dim)
-        k = processed.view(batch_size, 1, seq_len, self.dim)
-        v = processed.view(batch_size, 1, seq_len, self.dim)
+        q = processed.view(batch_size, 1, seq_len, self.dim).detach()
+        k = processed.view(batch_size, 1, seq_len, self.dim).detach()
+        v = processed.view(batch_size, 1, seq_len, self.dim).detach()

         out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
         out = flex_attention(q, k, v, block_mask=block_mask)

test/export/test_export.py

Lines changed: 19 additions & 4 deletions

@@ -926,9 +926,21 @@ def dynamic_mask_function(batch_idx, head_idx, q_idx, kv_idx):
             KV_LEN=seq_len,
             device=x.device,
         )
-        q = self.q_proj(processed).view(batch_size, 1, seq_len, self.dim)
-        k = self.k_proj(processed).view(batch_size, 1, seq_len, self.dim)
-        v = self.v_proj(processed).view(batch_size, 1, seq_len, self.dim)
+        q = (
+            self.q_proj(processed)
+            .view(batch_size, 1, seq_len, self.dim)
+            .detach()
+        )
+        k = (
+            self.k_proj(processed)
+            .view(batch_size, 1, seq_len, self.dim)
+            .detach()
+        )
+        v = (
+            self.v_proj(processed)
+            .view(batch_size, 1, seq_len, self.dim)
+            .detach()
+        )

         # Use flex_attention with torch.compile - during export, compile should be skipped
         backend = "inductor" if self.use_inductor else "eager"
@@ -1087,13 +1099,16 @@ def forward(self, x):
     to_13 = torch.ops.aten.to.dtype(argsort_3, torch.int32, False, False, torch.contiguous_format);  argsort_3 = None
     linear_1 = torch.ops.aten.linear.default(linear, q_proj_weight, q_proj_bias);  q_proj_weight = q_proj_bias = None
     view_1 = torch.ops.aten.view.default(linear_1, [2, 1, 128, 64]);  linear_1 = None
+    detach_19 = torch.ops.aten.detach.default(view_1);  view_1 = None
     linear_2 = torch.ops.aten.linear.default(linear, k_proj_weight, k_proj_bias);  k_proj_weight = k_proj_bias = None
     view_2 = torch.ops.aten.view.default(linear_2, [2, 1, 128, 64]);  linear_2 = None
+    detach_20 = torch.ops.aten.detach.default(view_2);  view_2 = None
     linear_3 = torch.ops.aten.linear.default(linear, v_proj_weight, v_proj_bias);  linear = v_proj_weight = v_proj_bias = None
     view_3 = torch.ops.aten.view.default(linear_3, [2, 1, 128, 64]);  linear_3 = None
+    detach_21 = torch.ops.aten.detach.default(view_3);  view_3 = None
     sdpa_score0 = self.sdpa_score0
     sdpa_mask0 = self.sdpa_mask0
-    flex_attention = torch.ops.higher_order.flex_attention(view_1, view_2, view_3, sdpa_score0, (128, 128, to_3, to_4, to_6, to_7, to_9, to_10, to_12, to_13, 128, 128, sdpa_mask0), 0.125, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': False, 'OUTPUT_MAX': False}, (), (detach,));  view_1 = view_2 = view_3 = sdpa_score0 = to_3 = to_4 = to_6 = to_7 = to_9 = to_10 = to_12 = to_13 = sdpa_mask0 = detach = None
+    flex_attention = torch.ops.higher_order.flex_attention(detach_19, detach_20, detach_21, sdpa_score0, (128, 128, to_3, to_4, to_6, to_7, to_9, to_10, to_12, to_13, 128, 128, sdpa_mask0), 0.125, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': False, 'OUTPUT_MAX': False}, (), (detach,));  detach_19 = detach_20 = detach_21 = sdpa_score0 = to_3 = to_4 = to_6 = to_7 = to_9 = to_10 = to_12 = to_13 = sdpa_mask0 = detach = None
     getitem = flex_attention[0]
     getitem_1 = flex_attention[1];  getitem_1 = None
     getitem_2 = flex_attention[2];  flex_attention = getitem_2 = None

test/functorch/test_aotdispatch.py

Lines changed: 3 additions & 2 deletions

@@ -7599,6 +7599,7 @@ def _tg3(y):
         self.assertEqual(ref_x.grad, x.grad)

     @patch("torch._functorch.config.guess_tangent_strides_as_outputs", True)
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
     def test_flex_attn_noncontiguous_tangents(self):
         with GradsNoForceContiguousContextManager() as ctx:
             E = 16  # embedding dim
@@ -7626,12 +7627,12 @@ def forward(self, x):

                 return y.transpose(1, 2).contiguous().view(B, T, E)

-        m = M()
+        m = M().cuda()
         B = 1
         T = 8

         def _inp():
-            return torch.randn(B, T, E, requires_grad=True)
+            return torch.randn(B, T, E, requires_grad=True, device="cuda")

         x = _inp()
         y = m(x)

test/inductor/test_flex_attention.py

Lines changed: 1 addition & 1 deletion

@@ -1863,7 +1863,7 @@ def score_mod_func(score, b, h, q, kv):
             (2, 2, 128, 4),
             device=device,
             dtype=torch.float64,
-            requires_grad=True,
+            requires_grad=False,
         )
         query, key, value = make_tensor(), make_tensor(), make_tensor()
         # floor_div is not decomposed in decomposition_table is empty

test/inductor/test_flex_decoding.py

Lines changed: 2 additions & 2 deletions

@@ -1176,14 +1176,14 @@ def score_mod_func(score, b, h, q, kv):
             (2, 2, 128, 4),
             dtype=dtype,
             device=device,
-            requires_grad=True,
+            requires_grad=False,
         )
         make_q = functools.partial(
             torch.randn,
             (2, 2, 8, 4),
             dtype=dtype,
             device=device,
-            requires_grad=True,
+            requires_grad=False,
         )
         query, key, value = make_q(), make_kv(), make_kv()
         # floor_div is not decomposed in decomposition_table is empty

torch/nn/attention/flex_attention.py

Lines changed: 6 additions & 0 deletions

@@ -1322,6 +1322,12 @@ def _validate_device(query: Tensor, key: Tensor, value: Tensor) -> None:
     """TODO: Remove once non cuda/cpu devices support is added
     We only need to check query since we have already that q,k,v are on the same device
     """
+    if query.device.type == "cpu" and (
+        query.requires_grad or key.requires_grad or value.requires_grad
+    ):
+        raise NotImplementedError(
+            "FlexAttention does not support backward on CPU. Please set the input requires_grad to False or use another device."
+        )
     supported_devices = {"cuda", "cpu", "xpu", "hpu"}
     if query.device.type not in supported_devices:
         raise ValueError(
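
The updated tests above work around this new check by detaching inputs or disabling requires_grad. A small user-side helper in the same spirit (hypothetical; not part of the PR) could look like this:

```python
import torch

def cpu_forward_only(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
    # Hypothetical helper: detach CPU tensors that require grad so the
    # forward-only CPU path of flex_attention does not raise the new
    # NotImplementedError. Tensors on other devices pass through unchanged.
    return tuple(
        t.detach() if t.device.type == "cpu" and t.requires_grad else t
        for t in tensors
    )
```

Usage would mirror the test changes, e.g. `q, k, v = cpu_forward_only(q, k, v)` before calling `flex_attention`.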
