[MPS] Make fused rms_norm traceable #150661
Conversation
By declaring it as an ATen op.
Fixes #150629
[ghstack-poisoned]
pytorch-bot bot commented Apr 4, 2025
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150661
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 32 Pending
As of commit b3dd4b4 with merge base 300e0ee.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Attention! native_functions.yaml was changed
If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.
Caused by:
By declaring it as an ATen op.
Fixes #150629
[ghstack-poisoned]
kimishpatel commented Apr 4, 2025
Can you add to the summary why #150629 introduced a regression? Also, do we know whether the inductor-generated code will be more performant than the native kernel?
Which is a regression, introduced by#150629 (comment) which I should have reviewed more thoroughly.- Defined `_rms_norm_fused`, added MPS-only implementation for it and dispatch from native::rms_norm_symint there in no-grad mode- Defined a decomp for it in `torch/_inductor/decomposition.py`- Added unit test to avoid those regressions in the futureTODO/Ideas: - Perhaps define it as non-decomposable - Make `torch.compiler.is_compiling` reflect to some sort of `at::Context` propertyFixes#150629cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov[ghstack-poisoned]
torch/_inductor/decomposition.py Outdated
| return aten.leaky_relu(self, negative_slope), torch.Tensor() |
| @register_decomposition(aten._rms_norm_fused) |
We don't want this to happen. We want inductor to use the fused implementation, not to decompose it.
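One way to express that, sketched here under the assumption that the fused op stays a real ATen op (a later revision of this PR registers it, renamed `_fused_rms_norm`, as a fallback in `torch/_inductor/lowering.py`), is to register an inductor fallback instead of a decomposition. The `make_fallback` helper below lives in that file; the op name is taken from this PR, so treat both as assumptions about the build:

```python
# Sketch: keep the op opaque to inductor so compiled graphs call the fused
# kernel directly, instead of decomposing it into elementwise/reduction ops.
# Assumes a PyTorch build that already defines the aten::_fused_rms_norm op.
import torch
from torch._inductor.lowering import make_fallback

make_fallback(torch.ops.aten._fused_rms_norm)
```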
| dispatch: |
|   CompositeImplicitAutograd: rms_norm_symint |
| - func: _rms_norm_fused(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor |
tbh I think you're making your life more complex for no good reason here.
You can make the op above `CompositeExplicitAutograd`.
| dispatch: |
|   CompositeImplicitAutograd: rms_norm_symint |
| - func: _rms_norm_fused(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor |
I am not sure if there is some naming convention, or if it's just coincidence, but when I search for the word "fused" in native_functions.yaml, it always comes before the name of the function. If there is such a naming convention, then this should be changed to `_fused_rms_norm`.
Which is a regression, introduced by #150629 (comment), which I should have reviewed more thoroughly.
- Defined `_fused_rms_norm`, added an MPS-only implementation for it, and dispatch to it from `rms_norm_symint`, which is registered as `CompositeImplicitAutograd`, i.e. it is not supposed to do any computations over Tensors, only dispatch to other ops
- Registered `_fused_rms_norm` as a fallback in `torch/_inductor/lowering.py`
- Added a unit test to avoid those regressions in the future

TODO:
- Get rid of this op, change the `rms_norm_symint` definition to `CompositeExplicitAutograd`, and implement the backward function in `tools/autograd/derivatives.yaml`
- Benchmark the compiler and re-enable the decomp as follows when compiled code is faster:
```python
@register_decomposition(aten._rms_norm_fused)
def rms_norm_fused(
    self: torch.Tensor, ndim: int, weight: torch.Tensor, eps: float
) -> torch.Tensor:
    dtr = [self.dim() - i - 1 for i in range(ndim)]
    return self * weight * (self.pow(2).mean(dtr, keepdim=True).add(eps).rsqrt())
```

Fixes #150629

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
malfet commented Apr 17, 2025
@pytorchbot merge -f "let's test in prod"
pytorchmergebot commented Apr 17, 2025
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
malfet commented Apr 17, 2025
@pytorchbot revert -m *Has decomp started to fail again" -c nosignal
❌ 🤖 pytorchbot command failed:
malfet commented Apr 17, 2025
@pytorchbot revert -m "Has decomp started to fail again" -c nosignal
pytorchmergebot commented Apr 17, 2025
@pytorchbot successfully started a revert job. Check the current status here.
This reverts commit 682f09e. Reverted #150661 on behalf of https://github.com/malfet due to Has decomp started to fail again ([comment](#150661 (comment)))
pytorchmergebot commented Apr 17, 2025
@malfet your PR has been successfully reverted.
Which is a regression, introduced by #150629 (comment), which I should have reviewed more thoroughly.
- Defined `_fused_rms_norm`, added an MPS-only implementation for it, and dispatch to it from `rms_norm_symint`, which is registered as `CompositeImplicitAutograd`, i.e. it is not supposed to do any computations over Tensors, only dispatch to other ops
- Registered `_fused_rms_norm` as a fallback in `torch/_inductor/lowering.py`
- Added a unit test to avoid those regressions in the future

TODO:
- Get rid of this op, change the `rms_norm_symint` definition to `CompositeExplicitAutograd`, and implement the backward function in `tools/autograd/derivatives.yaml`
- Benchmark the compiler and re-enable the decomp as follows when compiled code is faster:
```python
@register_decomposition(aten._rms_norm_fused)
def rms_norm_fused(
    self: torch.Tensor, ndim: int, weight: torch.Tensor, eps: float
) -> torch.Tensor:
    dtr = [self.dim() - i - 1 for i in range(ndim)]
    return self * weight * (self.pow(2).mean(dtr, keepdim=True).add(eps).rsqrt())
```

Fixes #150629

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
malfet commented Apr 17, 2025
@pytorchbot merge -f "Re-added has decomposition"
pytorchmergebot commented Apr 17, 2025
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Which is a regression, introduced by #150629 (comment), which I should have reviewed more thoroughly.
- Defined `_fused_rms_norm`, added an MPS-only implementation for it, and dispatch to it from `rms_norm_symint`, which is registered as `CompositeImplicitAutograd`, i.e. it is not supposed to do any computations over Tensors, only dispatch to other ops
- Registered `_fused_rms_norm` as a fallback in `torch/_inductor/lowering.py`
- Added a unit test to avoid those regressions in the future

TODO:
- Get rid of this op, change the `rms_norm_symint` definition to `CompositeExplicitAutograd`, and implement the backward function in `tools/autograd/derivatives.yaml`

Fixes #150629
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov