[DTensor] Add DTensor redistribute fwd/bwd datatype conversion to enable SimpleFSDP mixed precision training #150740
Conversation
pytorch-bot bot commented Apr 5, 2025 • edited
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150740
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit d8ddd63 with merge base 9699cc3:
FLAKY - The following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
wconstab left a comment
could you add a bit more context to the PR description?
- concise example / pseudocode for how this would be used (by fsdp)
- link to fsdp code with the full details
nit: we don't actually assert that reshard_replica_tensor is in dtype bf16? it could pass this test if it stays in f32, i think.
actually it didn't pass the test haha. The test asserts `reshard_replica_tensor` is `bfloat16`, which throws datatype mismatch errors when `reshard_replica_tensor` stays in `float32`.
The reason is that in the `replicate_to_replicate` case, `current_spec.placements` and `placements` here are the same, and `output` is the casted `local_tensor` in bfloat16.
Just FYI, https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L789-L801 shows that `assertEqual` would do a type cast when comparing two tensors.
sg, that's a good point.
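For illustration (not from the PR; the tensor values and the use of `torch.testing.assert_close` here are my own), a minimal example of why a value-only comparison can miss a dtype mismatch and an explicit dtype check is needed:

```python
import torch

a = torch.ones(4, dtype=torch.float32)
b = torch.ones(4, dtype=torch.bfloat16)

# With dtype checking disabled, the tensors are promoted to a common dtype and the
# value comparison passes even though their dtypes differ (this mirrors the
# permissive assertEqual behavior described above).
torch.testing.assert_close(a, b, check_dtype=False)

# Pinning down the precision requires an explicit dtype assertion.
assert b.dtype == torch.bfloat16
```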
XilunWu left a comment
I wonder if this type conversion must happen in `redistribute`. Can this be implemented in a less intrusive way than changing the DTensor API (e.g. as a hook)?
torch/distributed/tensor/_api.py Outdated
If I were reading the API docstring, one thing I would like to understand is the difference between this API and `dtensor.to(torch.bfloat16).redistribute(...)`.
This PR:
forward: fp32 -> bf16 -> all-gather
backward: bf16 -> fp32 -> reduce-scatter
Your code:
forward: fp32 -> bf16 -> all-gather
backward: bf16 -> reduce-scatter -> fp32
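A minimal sketch of the ordering difference, assuming placeholder collectives (`_all_gather`, `_reduce_scatter`) in place of the real mesh communication; this is illustrative, not the actual DTensor implementation:

```python
import torch

def _all_gather(t: torch.Tensor) -> torch.Tensor:
    return t  # placeholder for the forward collective

def _reduce_scatter(t: torch.Tensor) -> torch.Tensor:
    return t  # placeholder for the backward collective

class _CastInsideRedistribute(torch.autograd.Function):
    """Sketch of this PR's ordering: the backward cast happens BEFORE the
    collective, so the gradient reduction runs in backward_dtype (e.g. fp32)."""

    @staticmethod
    def forward(ctx, local_tensor, forward_dtype, backward_dtype):
        ctx.backward_dtype = backward_dtype
        return _all_gather(local_tensor.to(forward_dtype))  # fp32 -> bf16 -> all-gather

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.to(ctx.backward_dtype)            # bf16 -> fp32
        return _reduce_scatter(grad), None, None             # reduce-scatter in fp32

# With dtensor.to(torch.bfloat16).redistribute(...), the cast sits outside the
# redistribute autograd node, so its backward runs bf16 -> reduce-scatter -> fp32:
# the reduction itself happens in bf16.
```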
tianyu-l commented Apr 8, 2025
Could you think of a concrete proposal? If it works in eager, would the backward hooks make it harder for torch.compile to generate a full graph?
wanchaol left a comment
Nice feature added! I think the addition to the API makes sense; however, the mixed precision handling for autograd has some issues we need to resolve. Please see the inline comments.
Hmmm, I think there's something wrong here. We should consider the fact that `forward_dtype` and `backward_dtype` are BOTH optional. This means a user might only provide a `forward_dtype` and expect this autograd function to handle the rest. At least we should handle the case where only `forward_dtype` is provided. For the case where `backward_dtype` is provided, I am not sure what to do yet as there's no original dtensor grad dtype 🤔
The current change does not handle the above case correctly. I.e. if one passes in fwd_dtype only, the forward behavior is right, but in backward we should also convert the grad_out back to the original dtensor dtype AFTER redistribute if it does not match. Otherwise the precision of the DTensor and its grad would mismatch at a later stage.
ruisizhang123 Apr 8, 2025 • edited
This is a very good point. For now, since its major use case is SimpleFSDP (which requires both fwd_dtype and bwd_dtype), can we enforce an assertion in the API, asking users to specify both fwd_dtype and bwd_dtype if one of them is not None? 🤔️
I think we should always cast back to the original dtype, even if `backward_dtype` is given.
See my comment at #150740 (comment).
It is debatable whether, if `backward_dtype` is not given, we should perform the backward redistribute in `forward_dtype` or in the original dtype. I personally think it should be the latter.
Yeah, makes sense. I think it should be the original dtype; let's convert back to the original dtype if the local tensor dtype is different.
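A hedged sketch of the behavior agreed on above (the helper names here are illustrative, not the merged code): run the backward redistribute in `backward_dtype` when given, otherwise in the original dtype, and always return the gradient in the original dtype:

```python
import torch

def _reduce_scatter(t: torch.Tensor) -> torch.Tensor:
    return t  # placeholder for the backward collective

def _backward_grad(grad_output: torch.Tensor,
                   original_dtype: torch.dtype,
                   backward_dtype=None) -> torch.Tensor:
    # Communicate in backward_dtype if provided, else in the original dtype.
    comm_dtype = backward_dtype if backward_dtype is not None else original_dtype
    grad = _reduce_scatter(grad_output.to(comm_dtype))
    # Cast back so the parameter and its gradient never end up with mismatched
    # dtypes, even when only forward_dtype was provided.
    return grad if grad.dtype == original_dtype else grad.to(original_dtype)
```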
ruisizhang123 commented Apr 8, 2025
Ya, IIUC we are avoiding hooks in SimpleFSDP because it is hard to make the order correct in full-graph tracing, per this link.
tianyu-l left a comment
Left inline comments.
Also, please fix linting by `lintrunner init` and `lintrunner -a`.
Suggested change: `if forward_dtype is not None:` → `if forward_dtype is not None and forward_dtype != input._local_tensor.dtype:`
To add the ciflow label: this helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.
tianyu-l left a comment
LGTM!
tianyu-l commented Apr 13, 2025
@pytorchbot merge
pytorchmergebot commented Apr 13, 2025
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR adds mixed precision training support for SimpleFSDP, together with this change from PyTorch: pytorch/pytorch#150740. **(This is a placeholder for now, and shall be merged only after the PyTorch PR is landed.)**

#### Convergence on debug model

The loss curves are all benchmarked with `training.seed = 42` on 4 GPUs.

- Fully Shard Mode (['dp_shard'], [4]): FSDP2 and SimpleFSDP's losses are perfectly matched
  <img width="2234" alt="Screenshot 2025-04-05 at 4 06 04 PM" src="https://github.com/user-attachments/assets/e0ec66f2-948f-43e2-8d88-20243fbaccfc"/>
- Hybrid Shard Mode (['dp_replicate', 'dp_shard'], [2, 2]): FSDP2 and SimpleFSDP's losses are similar
  <img width="977" alt="Screenshot 2025-04-07 at 6 15 04 PM" src="https://github.com/user-attachments/assets/81c14637-e0a7-42aa-9b2e-c7092a287ea0"/>

The end-to-end SimpleFSDP mixed precision training integration has been proved to work properly in the PR from this fork: tianyu-l/pytorch_intern24#20.
…ble SimpleFSDP mixed precision training (pytorch#150740)

As titled, this PR adds additional `forward_dtype` and `backward_dtype` conversion in the DTensor `redistribute` API to enable SimpleFSDP's mixed precision training.

In the forward pass, the DTensor can be configured to be cast to `forward_dtype`; in the backward pass, the DTensor can be configured to be cast to `backward_dtype`.

1. **Correctness**: The end-to-end SimpleFSDP mixed precision training integration has been proved to work properly in the PR from this fork: tianyu-l#20. We are now migrating the code to official PyTorch DTensor.
2. **Example Usage**: There is an example in TorchTitan's SimpleFSDP implementation: pytorch/torchtitan#1060.

In the example below, a DTensor `x` is all-gather'ed along the `self.compute_placements`, with datatype cast to `self.param_dtype`. In the backward pass, additionally, the computed gradients are reduce-scatter'ed along the `self.grad_placements`, with datatype cast to `self.reduce_dtype`.

```python
output = x.redistribute(
    placements=self.compute_placements,
    forward_dtype=self.param_dtype,
    backward_dtype=self.reduce_dtype,
).to_local(grad_placements=self.grad_placements)
```

Under the hood, in `class Redistribute(torch.autograd.Function):`, the `forward` function first takes `x`'s local tensor and converts it to `forward_dtype` before all-gathering `x`. The `backward` function takes `grad_output` and converts it to `backward_dtype` before reduce-scattering `grad_output`.

Pull Request resolved: pytorch#150740
Approved by: https://github.com/tianyu-l
This is a follow-up on the previous DTensor redistribute PR: #150740, which enables SimpleFSDP's mixed-precision training.

In the most recent integration in TorchTitan: pytorch/torchtitan#1250, we found some discrepancies between SimpleFSDP's `fully_shard` and `replicate` modes when MPT is enabled. After debugging, I found the problem is in DTensor redistribute: `local_tensor` is taken out again from the original `input`, so the DTensor used for communication has its original precision instead of using `forward_dtype`.

This PR fixes this issue and corrects previously added test cases. After fixing the bug, the loss curves of `fully_shard` and `replicate` mode match perfectly.

Pull Request resolved: #154975
Approved by: https://github.com/tianyu-l
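A hedged illustration of the bug this follow-up describes (the function and variable names here are illustrative, not the real DTensor internals): after casting to `forward_dtype`, the communication must consume the casted tensor rather than a local tensor re-read from the original input.

```python
import torch

def _tensor_for_comm(input_local: torch.Tensor, forward_dtype: torch.dtype) -> torch.Tensor:
    """Returns the tensor the collective should use, illustrating the fix."""
    casted = input_local.to(forward_dtype)  # e.g. fp32 -> bf16
    # Buggy behavior: the local tensor was taken out of the original input again
    # after the cast, so the communication ran in the original (fp32) precision:
    #   tensor_for_comm = input_local
    # Fixed behavior: keep using the casted tensor for the communication.
    return casted
```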
As titled, this PR adds additional `forward_dtype` and `backward_dtype` conversion in the DTensor `redistribute` API to enable SimpleFSDP's mixed precision training.

In the forward pass, the DTensor can be configured to be cast to `forward_dtype`; in the backward pass, the DTensor can be configured to be cast to `backward_dtype`.

1. **Correctness**: The end-to-end SimpleFSDP mixed precision training integration has been proved to work properly in the PR from this fork: [dtensor] support mixed precision for redistribute tianyu-l/pytorch_intern24#20. We are now migrating the code to official PyTorch DTensor.
2. **Example Usage**: There is an example in TorchTitan's SimpleFSDP implementation: [SimpleFSDP] Add mixed precision training support torchtitan#1060.

In the example below, a DTensor `x` is all-gather'ed along the `self.compute_placements`, with datatype cast to `self.param_dtype`. In the backward pass, additionally, the computed gradients are reduce-scatter'ed along the `self.grad_placements`, with datatype cast to `self.reduce_dtype`.
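```python
output = x.redistribute(
    placements=self.compute_placements,
    forward_dtype=self.param_dtype,
    backward_dtype=self.reduce_dtype,
).to_local(grad_placements=self.grad_placements)
```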
Under the hood, in `class Redistribute(torch.autograd.Function):`, the `forward` function first takes `x`'s local tensor and converts it to `forward_dtype` before all-gathering `x`. The `backward` function takes `grad_output` and converts it to `backward_dtype` before reduce-scattering `grad_output`.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @tianyu-l