Enable FP8 row-wise scaled-mm for sm12x #155991


Closed
gau-nernst wants to merge 8 commits into pytorch:main from gau-nernst:fp8_sm120

Conversation

@gau-nernst (Contributor) commented Jun 14, 2025 (edited by pytorch-bot)

Update using Cutlass 3.x (2025/06/15)

Following @alexsamardzic's advice, I tried out the Cutlass 3.x API and it's impressive (the rated spec is 419 TFLOPS):

| M    | N    | K    | TFLOPS |
|------|------|------|--------|
| 16   | 4096 | 4096 | 17.56  |
| 64   | 4096 | 4096 | 69.63  |
| 256  | 4096 | 4096 | 266.57 |
| 1024 | 4096 | 4096 | 339.28 |
| 4096 | 4096 | 4096 | 388.91 |

This uses the same SM100 template. The only differences are:

  • Cluster size is fixed to <1,1,1>, since sm120 does not have the multicast feature.
  • Tile size is fixed to <128,128,128>, because the default kernel schedule does not support <64,128,128>. I will work a bit on improving perf for small M. Fixed: use KernelTmaWarpSpecializedPingpong when TileShape.M == 64.

Perf for small M is still bad, since it seems like Cutlass does not support TileShape.M < 64 for this kernel. It's possible to boost perf a bit by using TileShape <64,64,128>.

Original using SM89

I tried using the cutlass FP8 row-wise scaled-mm kernel for sm89 on sm120 (5090) and it works. I guess it makes sense, because sm120 matmul uses the standard sm80 PTX instructions (cp.async + mma and friends).

Simple benchmark script

```python
import torch
from torch._inductor.utils import do_bench_using_profiling

N, K = 4096, 4096
for M in [16, 64, 256, 1024, 4096]:
    A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).T
    scale_A = torch.ones(M, 1).cuda()
    scale_B = torch.ones(1, N).cuda()

    out = torch._scaled_mm(A, B, scale_A, scale_B, out_dtype=torch.bfloat16)
    out_ref = ((A.float() @ B.float()) * scale_A * scale_B).bfloat16()
    torch.testing.assert_close(out, out_ref)

    latency_us = do_bench_using_profiling(lambda: torch._scaled_mm(A, B, scale_A, scale_B, out_dtype=torch.bfloat16))
    tflops = (2 * M * N * K) / latency_us / 1e9
    print(f"{M=}\t{N=}\t{K=}\t{tflops:.2f} TFLOPS")
```
| M    | N    | K    | TFLOPS |
|------|------|------|--------|
| 16   | 4096 | 4096 | 25.73  |
| 64   | 4096 | 4096 | 71.84  |
| 256  | 4096 | 4096 | 86.40  |
| 1024 | 4096 | 4096 | 112.12 |
| 4096 | 4096 | 4096 | 121.24 |

According to the RTX Blackwell Whitepaper, FP8 MMA with FP32 accumulate is rated at 419 TFLOPS, so the results here are quite bad...
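For a rough sense of scale, here is a minimal sketch of the utilization math, using only the 419 TFLOPS peak and the measured numbers from the table above (no new data, just the ratio):

```python
# Rough utilization of the default SM89-kernel configs on sm120, relative to the
# 419 TFLOPS FP8 (FP32 accumulate) peak from the RTX Blackwell whitepaper.
# The measured values are copied from the table above.
peak_tflops = 419.0
measured_tflops = {16: 25.73, 64: 71.84, 256: 86.40, 1024: 112.12, 4096: 121.24}

for M, tflops in measured_tflops.items():
    print(f"M={M:<5d} {tflops:6.2f} TFLOPS  ({tflops / peak_tflops:5.1%} of peak)")
# Even the largest, compute-bound shape (M=4096) sits below ~30% of peak.
```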

However, if I change ThreadblockSwizzle to cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>:

| M    | N    | K    | TFLOPS |
|------|------|------|--------|
| 16   | 4096 | 4096 | 27.13  |
| 64   | 4096 | 4096 | 84.84  |
| 256  | 4096 | 4096 | 96.75  |
| 1024 | 4096 | 4096 | 110.21 |
| 4096 | 4096 | 4096 | 122.98 |

Small M slightly improves, but large M is still bad.

If I further change ThreadBlockShape=<128, 64, 128>, WarpShape=<64, 32, 128>, NumStages=3 for M > 256, which is taken from cutlass example 58, I get the following results:

| M    | N    | K    | TFLOPS |
|------|------|------|--------|
| 1024 | 4096 | 4096 | 313.28 |
| 4096 | 4096 | 4096 | 376.73 |

This is much closer to the hardware limit. It also means this kernel is sufficient to get the most perf out of sm120; it only needs better-tuned configs.

To make sure this high perf is only obtainable with GemmIdentityThreadblockSwizzle<1> + ThreadBlockShape=<128, 64, 128>, WarpShape=<64, 32, 128>, NumStages=3, I also tried ThreadblockSwizzleStreamK + ThreadBlockShape=<128, 64, 128>, WarpShape=<64, 32, 128>, NumStages=3:

| M    | N    | K    | TFLOPS |
|------|------|------|--------|
| 1024 | 4096 | 4096 | 144.03 |
| 4096 | 4096 | 4096 | 156.86 |

A bit better than the current configs, but still very far from the hardware limit.

@alexsamardzic I noticed you chose these configs in #149978. Do you have any numbers on how the current configs perform on sm89?

Update: using the Triton kernel codegen-ed by inductor, via compiled_scaled_mm = torch.compile(torch._scaled_mm, dynamic=False, mode="max-autotune-no-cudagraphs"):

| M    | N    | K    | TFLOPS |
|------|------|------|--------|
| 16   | 4096 | 4096 | 25.60  |
| 64   | 4096 | 4096 | 71.74  |
| 256  | 4096 | 4096 | 161.64 |
| 1024 | 4096 | 4096 | 185.89 |
| 4096 | 4096 | 4096 | 215.53 |

Better than the default configs, but still far from the config above for compute-bound shapes.
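For reference, a minimal sketch of reproducing the Triton/inductor measurement, reusing the same benchmark loop and do_bench_using_profiling helper as the script above:

```python
import torch
from torch._inductor.utils import do_bench_using_profiling

# Let inductor autotune a Triton scaled-mm kernel instead of calling the
# CUTLASS kernel behind torch._scaled_mm directly.
compiled_scaled_mm = torch.compile(
    torch._scaled_mm, dynamic=False, mode="max-autotune-no-cudagraphs"
)

N, K = 4096, 4096
for M in [16, 64, 256, 1024, 4096]:
    A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).T
    scale_A = torch.ones(M, 1, device="cuda")
    scale_B = torch.ones(1, N, device="cuda")

    # First call triggers compilation/autotuning; the timing below reflects the tuned kernel.
    compiled_scaled_mm(A, B, scale_A, scale_B, out_dtype=torch.bfloat16)

    latency = do_bench_using_profiling(
        lambda: compiled_scaled_mm(A, B, scale_A, scale_B, out_dtype=torch.bfloat16)
    )
    tflops = (2 * M * N * K) / latency / 1e9  # same conversion as the script above
    print(f"{M=}\t{N=}\t{K=}\t{tflops:.2f} TFLOPS")
```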

cc @ptrblck @msaroufim @eqy @jerryzh168

@pytorch-bot (bot) commented Jun 14, 2025 (edited)

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155991

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 46e1250 with merge base 517d299:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@alexsamardzic (Collaborator) commented:

> @alexsamardzic I noticed you chose these configs in #149978. Do you have any numbers on how the current configs perform on sm89?

I remember doing a lot of benchmarking, but I'm afraid I don't have the results saved... The SM89 configs eventually put there were meant to improve performance for smaller M; but choosing these configs is guesswork anyway - this is why auto-tuning is so important.

Anyway, is there any particular reason to use the SM89 kernel for SM120? There are SM90 and SM100 kernels in the same source file, written using the CUTLASS 3.x API (while the SM89 kernel uses the 2.x API), and these may be a better match.


@gau-nernst (Contributor, Author) commented:

> Anyway, is there any particular reason to use the SM89 kernel for SM120?

Not exactly. I haven't tried using the cutlass 3.x API for this (I can try later). But if the sm89 kernel works and we can get good perf with tuned configs, it's not strictly necessary to use the cutlass 3.x API, is it?

@alexsamardzic (Collaborator) commented:

> Not exactly. I haven't tried using the cutlass 3.x API for this (I can try later). But if the sm89 kernel works and we can get good perf with tuned configs, it's not strictly necessary to use the cutlass 3.x API, is it?

I may be wrong, but I would expect better performance with a 3.x kernel, as it targets archs that are closer to sm120 (TMA etc.); sm89 is really just a weird corner case regarding fp8 support.

@gau-nernst (Contributor, Author) commented Jun 14, 2025 (edited):

I don't think sm120 has TMA? Performant gemm in sm120 is still just cp.async+mma I think. No TMA or tcgen05 like sm90/sm100. Hence I'm expecting fp8 gemm in sm120 to be similar to that in sm89 (if we don't count new dtypes like mxfp8)

Edit: my mistake, seems like sm120 has TMA: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp


@eqy (Collaborator) left a comment:

This works without build changes to e.g. "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"?

@gau-nernst (Contributor, Author) commented Jun 15, 2025 (edited):

@eqy I probably need to add 120a to it. Thanks for the check. Is there anywhere else I should update? And how should I test locally (without building everything like flash attention) that the build is working as expected?

Locally, I compile pytorch with this command

DEBUG=1 USE_DISTRIBUTED=0 USE_MKLDNN=0 USE_CUDA=1 BUILD_TEST=0 USE_FBGEMM=0 USE_NNPACK=0 USE_QNNPACK=0 USE_XNNPACK=0 USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 CMAKE_LINKER_TYPE=MOLD TORCH_CUDA_ARCH_LIST="12.0 12.0a" python setup.py develop

(maybe the 12.0 was unnecessary; I was messing around with some compile problems)

Should I add an entry for sm120 like below as well?

if("${_arch}"STREQUAL"100a")
if(_existing_arch_flagsMATCHES".*compute_100.*")
list(APPEND _file_compile_flags"-gencode;arch=compute_100a,code=sm_100a")
endif()
endif()


@eqy (Collaborator) commented Jun 15, 2025:

> @eqy I probably need to add 120a to it. Thanks for the check. Is there anywhere else I should update? And how should I test locally (without building everything like flash attention) that the build is working as expected?
>
> Locally, I compile pytorch with this command
>
> DEBUG=1 USE_DISTRIBUTED=0 USE_MKLDNN=0 USE_CUDA=1 BUILD_TEST=0 USE_FBGEMM=0 USE_NNPACK=0 USE_QNNPACK=0 USE_XNNPACK=0 USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 CMAKE_LINKER_TYPE=MOLD TORCH_CUDA_ARCH_LIST="12.0 12.0a" python setup.py develop
>
> (maybe the 12.0 was unnecessary; I was messing around with some compile problems)

My guess is that you should be able to test locally with TORCH_CUDA_ARCH_LIST=12.0, as the 12.0a should be added for just that compilation unit if the CMake config is updated correctly.
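For example, a minimal sketch of such a local sanity check (plain PyTorch APIs only; the exact strings in the arch list depend on how TORCH_CUDA_ARCH_LIST was set for the build):

```python
import torch

# Confirm which SM architectures this build targets and what the local GPU reports.
print("compiled arch list:", torch.cuda.get_arch_list())         # expect an sm_120 entry for this setup
print("device capability:", torch.cuda.get_device_capability())  # (12, 0) on a 5090

# Tiny FP8 row-wise scaled mm: this should hit the CUTLASS path rather than
# raising a "not supported on this architecture" style error.
A = torch.randn(16, 64, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(16, 64, device="cuda").to(torch.float8_e4m3fn).T
scale_A = torch.ones(16, 1, device="cuda")
scale_B = torch.ones(1, 16, device="cuda")
out = torch._scaled_mm(A, B, scale_A, scale_B, out_dtype=torch.bfloat16)
print("output:", tuple(out.shape), out.dtype)
```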



using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
// on sm120, KernelScheduleAuto resolves to KernelTmaWarpSpecializedCooperativeSm120<2>>,
// which does not support TileShape.M < 128
Review comment from a Collaborator on the snippet above:

As of CUTLASS 3.9.2 right? May want to specify that.

@drisspg added the module: cuda and release notes: cuda labels Jun 16, 2025
@drisspg requested a review from eqy Jun 16, 2025 17:39
@drisspg added the ciflow/binaries label Jun 16, 2025
@gau-nernst (Contributor, Author) commented:

Can we merge this? Thank you!

@drisspg (Contributor) commented:

@pytorchbot merge


@pytorch-bot added the ciflow/trunk label Jun 17, 2025
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team.

Advanced Debugging: check the merge workflow status here.



Reviewers

@Skylion007 left review comments
@eqy approved these changes
@drisspg approved these changes
@syed-ahmed: awaiting requested review (code owner)
@nWEIdia: awaiting requested review
@ngimel: awaiting requested review


Labels

ciflow/binaries, ciflow/trunk, Merged, module: cuda, open source, release notes: cuda


7 participants

@gau-nernst, @alexsamardzic, @eqy, @drisspg, @pytorchmergebot, @Skylion007, @pytorchbot
