Enable FP8 row-wise scaled-mm for sm12x #155991
Conversation
pytorch-bot bot commented Jun 14, 2025 (edited)
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155991
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 46e1250 with merge base 517d299.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
alexsamardzic commented Jun 14, 2025
I remember doing a lot of benchmarking, but I'm afraid I don't have the results saved... The SM89 configs eventually put there were meant to improve performance for smaller M; but choosing these configs is guesswork anyway, which is why auto-tuning is so important. Anyway, is there any particular reason to use the SM89 kernel for SM120? There are SM90 and SM100 kernels in the same source file, written using the CUTLASS 3.x API (while the SM89 kernel uses the 2.x API), and these may be a better match.
gau-nernst commented Jun 14, 2025
Not exactly. I haven't tried using the CUTLASS 3.x API for this (I can try later). But if the sm89 kernel works and we can get good perf with tuned configs, is it really necessary to use the CUTLASS 3.x API?
alexsamardzic commented Jun 14, 2025
I may be wrong, but I would expect better performance with a 3.x kernel, as it targets archs that are closer to sm120 (TMA etc.); sm89 is really just a weird corner case regarding FP8 support.
gau-nernst commented Jun 14, 2025 (edited)
I don't think sm120 has TMA? A performant GEMM on sm120 is still just cp.async + mma, I think, with no TMA or tcgen05 like sm90/sm100. Hence I'm expecting FP8 GEMM on sm120 to be similar to that on sm89 (if we don't count new dtypes like mxfp8). Edit: my mistake, it seems sm120 does have TMA: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp
eqy left a comment
This works without build changes to e.g.
Line 119 in 655b3b1: "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
gau-nernst commented Jun 15, 2025 (edited)
@eqy I probably need to add […]

Locally, I compile PyTorch with this command: […] (maybe the […])

Should I add an entry for sm120 like the one below as well? (Lines 106 to 110 in 655b3b1: […])
eqy commented Jun 15, 2025
My guess is that you should be able to test locally with […]
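The specific suggestion above was collapsed in the original thread. Purely as a hypothetical illustration (not necessarily what was meant here), one quick way to sanity-check a local sm120 build from Python is a sketch like the following; the shapes and dtypes are assumptions:

```python
# Illustrative sanity check only: verify the local build includes sm_120 and
# that the FP8 row-wise scaled-mm path runs on the card.
import torch

print(torch.version.cuda, torch.cuda.get_device_capability())  # expect (12, 0) on a 5090
print(torch.cuda.get_arch_list())                               # expect 'sm_120' in the list

a = torch.randn(64, 128, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(64, 128, device="cuda").to(torch.float8_e4m3fn).t()  # column-major mat2
scale_a = torch.ones(64, 1, device="cuda")  # per-row scales for A (float32)
scale_b = torch.ones(1, 64, device="cuda")  # per-column scales for B (float32)
out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape, out.dtype)
```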
using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
// on sm120, KernelScheduleAuto resolves to KernelTmaWarpSpecializedCooperativeSm120<2>>,
// which does not support TileShape.M < 128
As of CUTLASS 3.9.2, right? You may want to specify that.
gau-nernst commented Jun 17, 2025
Can we merge this? Thank you!
drisspg commented Jun 17, 2025
@pytorchbot merge
pytorchmergebot commented Jun 17, 2025
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Update using Cutlass 3.x (2025/06/15)
Following @alexsamardzic's advice, I tried out the Cutlass 3.x API and it's impressive (the rated spec is 419 TFLOPS).
This uses the same SM100 template. The only differences are:
- Cluster shape is fixed to <1,1,1>, since sm120 does not have the multicast feature.
- Tile size is fixed to <128,128,128>, since the default kernel schedule does not support <64,128,128>. (I planned to work a bit on improving perf for small M; fixed by using KernelTmaWarpSpecializedPingpong when TileShape.M == 64.)
- Perf for small M is still bad, since it seems Cutlass does not support TileShape.M < 64 for this kernel. It's possible to boost perf a bit by using TileShape <64,64,128>.

Original using SM89
I tried using the cutlass FP8 row-wise scaled-mm kernel for sm89 on sm120 (5090) and it works. I guess it makes sense because sm120 matmul uses the standard sm80 PTX instructions (cp.async + mma and friends).

Simple benchmark script: […]
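The actual benchmark script was collapsed in the original post. As a rough, hypothetical stand-in, a minimal row-wise scaled torch._scaled_mm benchmark might look like the sketch below; the shapes, iteration count, and bf16 output dtype are assumptions, not the author's values:

```python
# Hypothetical benchmark sketch: time a row-wise scaled FP8 matmul via
# torch._scaled_mm and report the achieved TFLOPS.
import torch

def bench_scaled_mm(M, N, K, iters=100):
    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # mat2 must be column-major
    scale_a = torch.rand(M, 1, device="cuda", dtype=torch.float32)    # per-row scales for A
    scale_b = torch.rand(1, N, device="cuda", dtype=torch.float32)    # per-column scales for B
    run = lambda: torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)

    run()  # warmup
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        run()
    end.record()
    torch.cuda.synchronize()

    ms = start.elapsed_time(end) / iters
    tflops = 2 * M * N * K / (ms * 1e-3) / 1e12  # a GEMM does 2*M*N*K FLOPs
    print(f"M={M:>6}  {ms:.3f} ms  {tflops:.1f} TFLOPS")

for M in (16, 256, 4096):
    bench_scaled_mm(M, 8192, 8192)
```

Comparing the measured TFLOPS against the whitepaper peak quoted below gives the utilization being discussed here.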
According to the RTX Blackwell Whitepaper, FP8 MMA with FP32 accumulate is rated at 419 TFLOPS, so the result here is quite bad...
However, if I change ThreadblockSwizzle to cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, small M slightly improves, but large M is still bad.
If I further change to ThreadBlockShape=<128, 64, 128>, WarpShape=<64, 32, 128>, NumStages=3 for M > 256, which is taken from cutlass example 58, I get results that are much closer to the hardware limit. It also means this kernel is sufficient to get the most perf out of sm120; we only need better-tuned configs.
To make sure this high perf is only obtainable with GemmIdentityThreadblockSwizzle<1> + ThreadBlockShape=<128, 64, 128>, WarpShape=<64, 32, 128>, NumStages=3, I also tried ThreadblockSwizzleStreamK + ThreadBlockShape=<128, 64, 128>, WarpShape=<64, 32, 128>, NumStages=3. It is a bit better than the current configs, but still very far from the hardware limit.
@alexsamardzic I noticed you chose these configs in #149978. Do you have any numbers on how the current configs perform on sm89?
Update: Using Triton code generated by Inductor

compiled_scaled_mm = torch.compile(torch._scaled_mm, dynamic=False, mode="max-autotune-no-cudagraphs")

Better than the default configs, but still far from the config above for compute-bound shapes.
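As a brief usage sketch (the input construction mirrors the hypothetical benchmark above and is an assumption, not part of the original post), the compiled path can be exercised like this; the first call triggers Inductor's Triton autotuning:

```python
# Sketch: exercise the max-autotune Inductor path for row-wise scaled FP8 matmul.
import torch

compiled_scaled_mm = torch.compile(
    torch._scaled_mm, dynamic=False, mode="max-autotune-no-cudagraphs"
)

M, N, K = 4096, 8192, 8192  # assumed shapes
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major mat2
scale_a = torch.rand(M, 1, device="cuda", dtype=torch.float32)
scale_b = torch.rand(1, N, device="cuda", dtype=torch.float32)

# First call autotunes over Inductor's Triton GEMM configs; later calls reuse the winner.
out = compiled_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape, out.dtype)
```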
cc @ptrblck @msaroufim @eqy @jerryzh168