[CPU][Flex attn] Add a readable error message for the backward path (#169646)
Fixes #169224. Flex attention does not support the backward path on CPU. This PR adds a readable and meaningful error message for that case.

Before:
```
Traceback (most recent call last):
  File "/workspace/test_flex_attn.py", line 24, in <module>
    output = flex_attention(query, key, value, block_mask=block_mask)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 940, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1019, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1003, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1757, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1452, in codegen_and_compile
    graph.run(*example_inputs)
  File "/workspace/pytorch/torch/_inductor/graph.py", line 987, in run
    return super().run(*args)
  File "/workspace/pytorch/torch/fx/interpreter.py", line 200, in run
    self.env[node] = self.run_node(node)
  File "/workspace/pytorch/torch/_inductor/graph.py", line 1726, in run_node
    result = super().run_node(n)
  File "/workspace/pytorch/torch/fx/interpreter.py", line 295, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/workspace/pytorch/torch/_inductor/graph.py", line 1257, in call_function
    return super().call_function(target, args, kwargs)
  File "/workspace/pytorch/torch/fx/interpreter.py", line 375, in call_function
    return target(*args, **kwargs)
torch._inductor.exc.InductorError: IndexError: tuple index out of range

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```

After:
```
Traceback (most recent call last):
  File "/workspace/test_flex_attn.py", line 24, in <module>
    output = flex_attention(query, key, value, block_mask=block_mask)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 926, in compile_wrapper
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/nn/attention/flex_attention.py", line 1481, in flex_attention
    _validate_device(query, key, value)
  File "/workspace/pytorch/torch/nn/attention/flex_attention.py", line 1332, in _validate_device
    raise NotImplementedError(
NotImplementedError: FlexAttention does not support backward on CPU. Please set the input requires_grad to False or use another device.
```

Pull Request resolved: #169646
Approved by: https://github.com/mingfeima, https://github.com/mlazos
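For context, below is a minimal sketch of a script that exercises this path. The shapes, the causal `mask_mod`, and the `detach()` workaround are illustrative assumptions rather than part of the PR; the PR itself only changes which error is raised.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative shapes (assumed); any CPU inputs that require grad hit the check.
B, H, S, D = 1, 2, 128, 64
query = torch.randn(B, H, S, D, requires_grad=True)
key = torch.randn(B, H, S, D, requires_grad=True)
value = torch.randn(B, H, S, D, requires_grad=True)

def causal(b, h, q_idx, kv_idx):
    # Standard causal mask_mod: each query attends only to earlier positions.
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, B, H, S, S)  # defaults to device="cpu"

try:
    output = flex_attention(query, key, value, block_mask=block_mask)
except NotImplementedError as e:
    # With this PR, the failure surfaces up front via _validate_device:
    # "FlexAttention does not support backward on CPU. Please set the
    #  input requires_grad to False or use another device."
    print(e)

# Forward-only use on CPU still works once the inputs no longer require grad:
output = flex_attention(
    query.detach(), key.detach(), value.detach(), block_mask=block_mask
)
```

The practical difference is that the check runs eagerly at the `flex_attention` call instead of failing deep inside Inductor codegen with an opaque `IndexError`.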