add privateuse1 device type to pre forward hook of fsdp #149487


Closed

Conversation

@garfield1997 (Contributor) commented Mar 19, 2025 (edited by pytorch-bot bot):

add privateuse1 device type to pre forward hook of fsdp

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot (bot) commented Mar 19, 2025 (edited):

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149487

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (2 Unrelated Failures)

As of commit e62bfb9 with merge base 41c97a7:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the oncall: distributed (Add this issue/PR to distributed oncall triage queue) and release notes: distributed (fsdp) (release notes category) labels on Mar 19, 2025
Comment on lines +130 to +140:

if self._device.type in [
    "cuda",
    "hpu",
    "xpu",
    "mtia",
    torch._C._get_privateuse1_backend_name(),
]:
Contributor:

@albanD @fffrog Please take a look, thanks!
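
For readers unfamiliar with the privateuse1 mechanism, here is a minimal sketch (not code from this PR; the backend name "my_accel" is hypothetical) of why the last entry of that list has to be looked up at runtime rather than hard-coded:

import torch

# Out-of-tree backends rename the generic "privateuse1" device type to their own
# name at import time, so the hook must query the current name dynamically.
torch.utils.rename_privateuse1_backend("my_accel")  # hypothetical backend name

assert torch._C._get_privateuse1_backend_name() == "my_accel"

# A tensor placed on the custom device reports device.type == "my_accel", which
# now matches the last entry of the list in the diff above.
allowed = ["cuda", "hpu", "xpu", "mtia", torch._C._get_privateuse1_backend_name()]
print("my_accel" in allowed)  # True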

Collaborator:

Why not check for accelerator here?
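
For context, a rough sketch of the kind of accelerator-based check being suggested (assuming the torch.accelerator API present in recent PyTorch builds; the helper name is made up and this is not the code that was merged):

import torch

def on_supported_accelerator(device: torch.device) -> bool:
    # torch.accelerator.current_accelerator() returns the device of the build's
    # accelerator backend (cuda, xpu, mtia, a renamed privateuse1 backend, ...)
    # or None on CPU-only builds.
    acc = torch.accelerator.current_accelerator()
    return acc is not None and device.type == acc.type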

Contributor (Author):

I have added a check here; please take a look and see if this works. @albanD

@garfield1997 (Contributor, Author):

@pytorchbot merge

@pytorch-bot (bot):

Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@garfield1997 (Contributor, Author):

@albanD If you have time, please take a look, thanks!

@cyyever (Collaborator):

@pytorchbot rebase


@pytorchmergebot (Collaborator):

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.

@pytorchmergebot (Collaborator):

Successfully rebased fsdp_device_type onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fsdp_device_type && git pull --rebase).

@cyyever (Collaborator):

@pytorchbot merge


@pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Mar 31, 2025
@pytorchmergebot (Collaborator):

Merge failed

Reason: Approvers from one of the following sets are needed:

  • Distributed (wconstab, mrshenli, pritamdamania87, zhaojuanmao, rohan-varma, ...)
  • superuser (pytorch/metamates)
  • Core Reviewers (mruberry, lezcano, Skylion007, ngimel, peterbell10, ...)
  • Core Maintainers (soumith, gchanan, ezyang, dzhulgakov, malfet, ...)
Details for Dev Infra team. Raised by workflow job.

Failing merge rule: Core Maintainers

@garfield1997 (Contributor, Author):

@pytorchbot merge


@pytorchmergebot (Collaborator):

Merge failed

Reason: Approvers from one of the following sets are needed:

  • Distributed (wconstab, mrshenli, pritamdamania87, zhaojuanmao, rohan-varma, ...)
  • superuser (pytorch/metamates)
  • Core Reviewers (mruberry, lezcano, Skylion007, ngimel, peterbell10, ...)
  • Core Maintainers (soumith, gchanan, ezyang, dzhulgakov, malfet, ...)
Details for Dev Infra team. Raised by workflow job.

Failing merge rule: Core Maintainers

@pytorch-bot bot removed the ciflow/trunk (Trigger trunk jobs on your pull request) label on Apr 8, 2025
@shink (Contributor) left a comment:

Thanks!

@cyyever requested a review from @albanD on April 8, 2025 05:33
@garfield1997 (Contributor, Author):

@albanD If you have time, please take a look, thanks!

@albanD (Collaborator):

The lint failure is clearly related and I'm not sure about the other distributed tests. We should check that.

Also, FYI @wconstab, in case you have any concerns with this; if not, I'll accept it once the CI is green.

@garfield1997 (Contributor, Author):

@albanD It looks like torch.accelerator.is_available can't be traced, which caused some tests to fail. We should fall back to checking the device-type string instead.
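
A sketch of the string-based fallback described above, under the assumption that it keeps the same device-type list as the diff (the helper name is hypothetical; this is not a verbatim copy of the merged change):

import torch

# Plain string comparison keeps the pre-forward hook free of calls that Dynamo
# refuses to trace, such as torch.accelerator.is_available().
_ACCELERATOR_TYPES = (
    "cuda",
    "hpu",
    "xpu",
    "mtia",
    torch._C._get_privateuse1_backend_name(),
)

def is_supported_accelerator(device: torch.device) -> bool:
    return device.type in _ACCELERATOR_TYPES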

@albanD (Collaborator) left a comment:

What was the error exactly?
Sounds ok if CI is green

@garfield1997 (Contributor, Author) commented Apr 18, 2025 (edited):

@albanD The failing test was test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_aot_eager, and with the environment variable TORCH_LOGS=+dynamo the error was:

[rank3]:V0418 09:38:26.343000 1727598 torch/_dynamo/symbolic_convert.py:556] [0/0] [__graph_breaks]   Explanation: Dynamo developers have intentionally marked that the function `is_available` in file `/projs/framework/xushuo/pytorch_gar/venv/maingpu/lib/python3.10/site-packages/torch/accelerator/__init__.py` should not be traced.
[rank3]:V0418 09:38:26.343000 1727598 torch/_dynamo/symbolic_convert.py:556] [0/0] [__graph_breaks]   Hint: Avoid calling the function `is_available`.
[rank3]:V0418 09:38:26.343000 1727598 torch/_dynamo/symbolic_convert.py:556] [0/0] [__graph_breaks]   Hint: Apply `@torch._dynamo.dont_skip_tracing` to the function `is_available` to force tracing into the function. More graph breaks may occur as a result of attempting to trace into the function.
[rank3]:V0418 09:38:26.343000 1727598 torch/_dynamo/symbolic_convert.py:556] [0/0] [__graph_breaks]   Hint: Please file an issue to PyTorch.

which was the cause of the following error:

[rank1]:E0418 09:44:22.242000 1733492 torch/testing/_internal/common_distributed.py:741]  exiting process 1 with exit code: 10
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741] Caught exception:
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741] Traceback (most recent call last):
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]   File "/projs/framework/xushuo/pytorch_gar/venv/maingpu/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 734, in run_test
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]     getattr(self, test_name)()
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]   File "/projs/framework/xushuo/pytorch_gar/venv/maingpu/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 607, in wrapper
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]     fn()
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]   File "/projs/framework/xushuo/pytorch_gar/venv/maingpu/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3154, in wrapper
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]     method(*args, **kwargs)
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]   File "/projs/framework/xushuo/pytorch_gar/venv/maingpu/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 1875, in wrapper
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]     return fn(*args, **kwargs)
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]   File "/projs/framework/xushuo/pytorch_gar/test/distributed/_composable/fsdp/test_fully_shard_compile.py", line 705, in test_nested_fully_shard_backend_aot_eager
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]     self._test_traceable_fsdp(
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]   File "/projs/framework/xushuo/pytorch_gar/test/distributed/_composable/fsdp/test_fully_shard_compile.py", line 574, in _test_traceable_fsdp
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]     losses_compiled = test_compiled()
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]   File "/projs/framework/xushuo/pytorch_gar/test/distributed/_composable/fsdp/test_fully_shard_compile.py", line 540, in test_compiled
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]     self.assertEqual(len(counters["graph_break"]), 1)
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]   File "/projs/framework/xushuo/pytorch_gar/venv/maingpu/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 4095, in assertEqual
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741]     raise error_metas.pop()[0].to_error(  # type: ignore[index]
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741] AssertionError: Scalars are not equal!
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741] Expected 1 but got 2.
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741] Absolute difference: 1
[rank0]:E0418 09:44:22.242000 1733491 torch/testing/_internal/common_distributed.py:741] Relative difference: 1.0
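
To make the failure mode concrete, a hedged repro sketch of the extra graph break (whether is_available is skipped by Dynamo depends on the PyTorch version; the counter inspection mirrors what the failing test does):

import torch
from torch._dynamo.utils import counters

@torch.compile(backend="eager")
def fn(x):
    # Per the log above, torch.accelerator.is_available is marked as
    # not-to-be-traced, so hitting it inside a compiled region graph-breaks.
    if torch.accelerator.is_available():
        return x + 1
    return x - 1

fn(torch.randn(4))
print(dict(counters["graph_break"]))  # lists the graph-break reasons Dynamo recorded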

@albanD (Collaborator):

@pytorchbot merge

Thanks for the details, I'll follow up with the dynamo team to see if we can do something about that!


@pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Apr 18, 2025
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team.

Advanced Debugging: check the merge workflow status here.


Reviewers

@albanD approved these changes
@cyyever approved these changes
@fffrog approved these changes

Reviewers whose approvals may not affect merge requirements

@shink approved these changes

Assignees

No one assigned

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, oncall: distributed (Add this issue/PR to distributed oncall triage queue), open source, release notes: distributed (fsdp) (release notes category)

Projects

None yet

Milestone

No milestone

Development

Successfully merging this pull request may close these issues.

7 participants

@garfield1997 @cyyever @pytorchmergebot @albanD @shink @fffrog @pytorchbot
