[DeviceMesh] Fix error in fake-mode + TORCH_DISTRIBUTED_DEBUG #170765
base: main
Conversation
When setting `TORCH_DISTRIBUTED_DEBUG=DETAIL` I got this error:

```py
Traceback (most recent call last):
  File "train.py", line 1463, in <module>
    main(exit_stack)
  File "train.py", line 400, in main
    loss = model(inp)
  File "/my_conda_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1774, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
  File "/my_conda_env/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 943, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/my_conda_env/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 2437, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/my_conda_env/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 2412, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/my_conda_env/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/my_conda_env/lib/python3.11/site-packages/torch/__init__.py", line 2435, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/my_conda_env/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 2477, in compile_fx
    return compile_fx(
  File "/my_conda_env/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 2528, in compile_fx
    return _maybe_wrap_and_compile_fx_main(
  File "/my_conda_env/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 2605, in _maybe_wrap_and_compile_fx_main
    return _compile_fx_main(
  File "/my_conda_env/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 2800, in _compile_fx_main
    return aot_autograd(
  File "/my_conda_env/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 123, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/my_conda_env/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1115, in aot_module_simplified
    compiled_fn, _ = aot_stage2_compile(
  File "/my_conda_env/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 355, in aot_stage2_compile
    return aot_stage2_autograd(aot_state, aot_graph_capture)
  File "/my_conda_env/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 2002, in aot_stage2_autograd
    fw_module_str, bw_module_str = _log_fw_bw_graphs(
  File "/my_conda_env/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 1499, in _log_fw_bw_graphs
    str(fw_metadata),
  File "/my_conda_env/lib/python3.11/dataclasses.py", line 240, in wrapper
    result = user_function(self)
  File "<string>", line 3, in __repr__
  File "/my_conda_env/lib/python3.11/dataclasses.py", line 240, in wrapper
    result = user_function(self)
  File "<string>", line 3, in __repr__
  File "/my_conda_env/lib/python3.11/dataclasses.py", line 240, in wrapper
    result = user_function(self)
  File "<string>", line 3, in __repr__
  File "/my_conda_env/lib/python3.11/site-packages/torch/distributed/device_mesh.py", line 518, in __repr__
    device_mesh_repr += f", Mesh: {self.mesh.tolist()}"
  File "/my_conda_env/lib/python3.11/site-packages/torch/distributed/device_mesh.py", line 308, in mesh
    full_mesh = self._layout.remap_to_tensor(self._rank_map)
  File "/my_conda_env/lib/python3.11/site-packages/torch/distributed/_mesh_layout.py", line 306, in remap_to_tensor
    return rank_map.as_strided(
  File "/my_conda_env/lib/python3.11/site-packages/torch/utils/_stats.py", line 29, in wrapper
    return fn(*args, **kwargs)
  File "/my_conda_env/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1397, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/my_conda_env/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2155, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/my_conda_env/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1510, in _cached_dispatch_impl
    return self._dispatch_impl(func, types, args, kwargs)
  File "/my_conda_env/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2451, in _dispatch_impl
    (flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors(
  File "/my_conda_env/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2909, in validate_and_convert_non_fake_tensors
    validated_args = [validate(a) for a in flat_args]
  File "/my_conda_env/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2909, in <listcomp>
    validated_args = [validate(a) for a in flat_args]
  File "/my_conda_env/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2897, in validate
    raise AssertionError(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.as_strided.default(...)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
pytorch-bot bot commented Dec 18, 2025 • edited
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/170765
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 7b776e3 with merge base dbba85b:
BROKEN TRUNK - The following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
| "either have all its original dimensions (e.g., no slicing) " | ||
| "or it needs to contain the local rank" | ||
| ) | ||
| withtorch._subclasses.fake_tensor.unset_fake_temporarily(): |
So what is the behavior of tracing a call to `.mesh` in Dynamo?
The way I see it, it should not be part of the public interface that DeviceMesh internally uses Tensors; in fact, it does so less and less (thanks to CuTe layouts), and the `mesh` property is mainly a legacy feature.
Moreover, I expect all operations on DeviceMeshes to end up completely "desugared" in the Dynamo graph, after they have helped introduce the correct collectives.
Finally, DeviceMesh's internal Tensors are always on CPU, and their values are never data-dependent or anything, thus I don't see any point in supporting fake tensors...
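That reasoning is what the `unset_fake_temporarily()` line in the diff hunk above implements. A minimal sketch of the pattern, assuming an illustrative helper and shapes (the actual code lives in `_mesh_layout.remap_to_tensor`):

```py
import torch
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

def remap_to_tensor(rank_map: torch.Tensor) -> torch.Tensor:
    # The rank map is a small, always-CPU tensor whose values are never
    # data-dependent, so it is safe to operate on the real tensor even
    # while a fake mode is active.
    with unset_fake_temporarily():
        return rank_map.as_strided((2, 2), (2, 1))

with FakeTensorMode():
    remap_to_tensor(torch.arange(4))  # no longer raises
```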
Ahh interesting, we also use this one inside `__getitem__`. But since we now use layouts, I actually think we can remove that one.
fduwjj commented Dec 18, 2025
Can you kindly add a unit test for it, so that we can catch cases like this down the road?
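A hypothetical sketch of what such a regression test could look like, using a fake process group; the `FakeStore` setup, mesh shape, and the need to set `TORCH_DISTRIBUTED_DEBUG` before initialization are assumptions, not taken from this PR:

```py
import os
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"  # assumed: must be set before c10d reads it

import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.distributed.fake_pg import FakeStore

def test_repr_under_fake_mode() -> None:
    # Single-process setup with a fake backend, so no real collectives run.
    dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=4)
    try:
        mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp", "tp"))
        with FakeTensorMode():
            repr(mesh)  # exercises __repr__ -> .mesh; should not raise after the fix
    finally:
        dist.destroy_process_group()
```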