📚LeetCUDA: Modern CUDA Learn Notes with PyTorch for Beginners🐑, 200+ CUDA Kernels, Tensor Cores, HGEMM, FA-2 MMA.🎉
📚LeetCUDA: It includes Tensor/CUDA Cores, TF32/F16/BF16/F8, 📖200+ CUDA Kernels🔥 with PyTorch, 📖100+ LLM/CUDA🔥 blogs, 📖HGEMM⚡️ which can achieve 98%~100% of cuBLAS TFLOPS, and 📖flash-attn⚡️ using Tensor Cores with pure MMA PTX.
```bibtex
@misc{LeetCUDA@2025,
  title={LeetCUDA: A Modern CUDA Learn Notes with PyTorch for Beginners},
  url={https://github.com/xlite-dev/LeetCUDA.git},
  note={Open-source software available at https://github.com/xlite-dev/LeetCUDA.git},
  author={DefTruth and Many Others},
  year={2025}
}
```
- [2025-08-18]: 🤗cache-dit is released! A PyTorch-native Inference Engine with Hybrid Cache Acceleration and Parallelism for DiTs. Feel free to give it a try!
- [2025-01-08]: 🤖ffpa-attn is released! Yet another Faster Flash Prefill Attention with O(1)🎉 SRAM complexity for large headdim, 1.8x~3x↑🎉 vs SDPA EA: 📈L20 ~1.9x↑🎉, 📈A30 ~1.8x↑🎉, 📈4090 ~2.1x↑🎉.
- [2024-12-02]: ⚡️HGEMM is released! Write HGEMM from scratch using Tensor Cores with the WMMA, MMA and CuTe APIs, achieving peak🎉 performance.
- 📖 HGEMM-MMA 🎉🎉
- 📖 FlashAttention-MMA 🎉🎉
- 📖 200+ CUDA Kernels 🔥🔥
- 📖 100+ LLM/CUDA Blogs 🔥
- 📖 How to Contribute 👀👇
Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's default Tensor Cores algorithm, the HGEMM (WMMA/MMA/CuTe) in this repo (blue🔵) can achieve 98%~100% of its (orange🟠) performance. Please check the toy-hgemm library⚡️⚡️ or the HGEMM⚡️⚡️ repo for more details.
| 📚Feature | 📚Feature | 📚Feature | 📚Feature |
|---|---|---|---|
| ✔️CUDA/Tensor Cores | ✔️Loop over K | ✔️Tile Block(BMxBK) | ✔️Tile Threads(T 8x8) |
| ✔️WMMA(m16n16k16) | ✔️MMA(m16n8k16) | ✔️Pack LDST(128 bits) | ✔️SMEM Padding |
| ✔️Copy Async | ✔️Tile MMAs | ✔️Tile Warps | ✔️Multi Stages(2~4) |
| ✔️Register Double Buffers | ✔️Block Swizzle | ✔️Warp Swizzle | ✔️SMEM Swizzle(CuTe/MMA) |
| ✔️Collective Store(Shfl) | ✔️Layout NN | ✔️Layout TN | ✔️SGEMM FP32/TF32 |
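To ground the feature list above, here is a minimal, unoptimized WMMA m16n16k16 HGEMM sketch that only shows the "Tensor Cores + loop over K" part (illustrative only, not one of the repo's tuned kernels; the kernel name, layout assumptions and launch configuration are mine):

```c++
#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// Minimal HGEMM sketch: one warp per block computes one 16x16 tile of C = A @ B.
// A[M,K], B[K,N], C[M,N] are row-major half, with M, N, K divisible by 16.
// Launch: hgemm_wmma_naive_kernel<<<dim3(N / 16, M / 16), 32>>>(A, B, C, M, N, K);
__global__ void hgemm_wmma_naive_kernel(const half* A, const half* B, half* C,
                                        int M, int N, int K) {
  const int tile_m = blockIdx.y;  // which 16-row stripe of C this warp owns
  const int tile_n = blockIdx.x;  // which 16-col stripe of C this warp owns

  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;
  wmma::fill_fragment(c_frag, __float2half(0.0f));

  // Loop over K in steps of the WMMA atom (k = 16), accumulating into c_frag.
  for (int k = 0; k < K; k += 16) {
    wmma::load_matrix_sync(a_frag, A + tile_m * 16 * K + k, K);
    wmma::load_matrix_sync(b_frag, B + k * N + tile_n * 16, N);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
  }
  wmma::store_matrix_sync(C + tile_m * 16 * N + tile_n * 16, c_frag, N, wmma::mem_row_major);
}
```

Everything else in the table (block/warp tiling, cp.async multi-stage pipelines, swizzling, collective stores) is what closes the gap from this naive skeleton to 98%~100% of cuBLAS.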
I have also implemented FlashAttention-2 using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp, Shared KV SMEM, Fully Shared QKV SMEM, Prefetch Q s2r, Prefetch K/V g2s, QKV Fine-grained Tiling, Collective Store, etc. Please refer to flash-attn⚡️⚡️ for more details.
| 📚Feature | 📚Feature | 📚Feature | 📚Feature |
|---|---|---|---|
| ✔️Tensor Cores | ✔️Loop over N/D | ✔️Tile Block(Br, Bc) | ✔️MMA(m16n8k16) |
| ✔️Pack LDST(128 bits) | ✔️SMEM Swizzle/Padding | ✔️Copy Async | ✔️Tile MMAs |
| ✔️Tile Warps | ✔️Multi Stages(1/2) | ✔️Collective Store(Shfl) | ✔️Split KV/Q |
| ✔️Shared QKV SMEM | ✔️Prefetch Q s2r | ✔️Prefetch KV g2s | ✔️QKV Fine-grained Tiling |
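The "Loop over N/D" entry above relies on the standard online-softmax rescaling that flash-attention is built on: as each new KV tile is processed, the running row max, the running exp-sum and the partial output are rescaled so the final result equals softmax(QK^T)V exactly, without ever materializing the full score matrix. Below is a small host-side C++ sketch of that per-row update rule (illustrative only, not the repo's kernel code; all names are mine):

```c++
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Running statistics for one query row while streaming over KV tiles.
struct RowState {
  float m = -INFINITY;   // running max of all scores seen so far
  float l = 0.0f;        // running sum of exp(score - m)
  std::vector<float> o;  // running (unnormalized) output accumulator, length d
};

// Fold one KV tile into the row state:
//   s[j]    = score q . k_j for the j-th key in this tile (already scaled by 1/sqrt(d))
//   v[j][:] = the j-th value row of this tile (length d)
void online_softmax_update(RowState& st, const std::vector<float>& s,
                           const std::vector<std::vector<float>>& v, int d) {
  if (st.o.empty()) st.o.assign(d, 0.0f);

  // 1) New running max over the old max and this tile's scores.
  float m_new = st.m;
  for (float x : s) m_new = std::max(m_new, x);

  // 2) Rescale the old sum and partial output from base exp(-m_old) to exp(-m_new).
  const float scale = std::exp(st.m - m_new);
  st.l *= scale;
  for (int k = 0; k < d; ++k) st.o[k] *= scale;

  // 3) Accumulate this tile's contribution with the new max.
  for (std::size_t j = 0; j < s.size(); ++j) {
    const float p = std::exp(s[j] - m_new);
    st.l += p;
    for (int k = 0; k < d; ++k) st.o[k] += p * v[j][k];
  }
  st.m = m_new;
  // After the last tile, the attention output for this row is st.o[k] / st.l.
}
```

Roughly speaking, the MMA kernels apply this same rule per accumulator tile in registers, which is why only the current KV tile needs to be resident in SMEM at any time.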
Currently, for small-scale attention (B<=4, H<=48, SeqLen<=8192, D<=64), it can run faster than FA2/SDPA on some devices. For example, on NVIDIA RTX 3080 Laptop, the 📚Split Q + Fully Shared QKV SMEM method can achieve 55 TFLOPS (D=64), which is almost ~1.5x🎉 faster than FA2. On NVIDIA L20, the 🤖ffpa-attn method can achieve 104 TFLOPS (D=512), which is almost ~1.8x🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates~ (MMA Acc F16/F32, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)
| Algorithm | (B,H,N,D) | RTX 3080 Laptop | L20 | RTX 4090 |
|---|---|---|---|---|
| FlashAttention-2 | (1,8,8192,64) | 37 TFLOPS | 100 TFLOPS | 145 TFLOPS |
| share-qkv+stage2 | (1,8,8192,64) | 55 TFLOPS | 99 TFLOPS | 221 TFLOPS |
| FlashAttention-2 | (1,48,8192,64) | 37 TFLOPS | 109 TFLOPS | 163 TFLOPS |
| share-qkv+stage2 | (1,48,8192,64) | 48 TFLOPS | 107 TFLOPS | 224 TFLOPS |
| SDPA(EFFICIENT ATTENTION) | (1,48,8192,512) | 16 TFLOPS | 58 TFLOPS | 85 TFLOPS |
| 🤖ffpa-attn | (1,48,8192,512) | 39 TFLOPS | 104 TFLOPS | 200 TFLOPS |
| Precision Errors vs FA2/SDPA | / | max: < ~1e-3 | min: ~0.0 | mean: < ~1e-5 |
The Split KV and Split Q implementations have been carried out in flash-attn⚡️⚡️ for performance comparison. The Split KV method, which splits all Q, K, V across MMAs (Warps), is slower than the Split Q method, which splits only Q across MMAs (Warps) while keeping K and V accessible to all MMAs (Warps).
- 📚 Split KV (Basic, FlashAttention-1)
```c++
// Split QKV across MMA(Warps) using naive matmul MMA & Warp tiling policy.
// case: The layout of 8 MMA(2x4) [after] kWarpTileSeqLenQ x kWarpTileSeqLenK(2x2) -> 32x2,32x2=64x64:
// |   [64,64]   |    warp_KV 0    |    warp_KV 1    |    warp_KV 2    |    warp_KV 3    |
// |  warp_QP 0  |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// |  warp_QP 0  |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// |  warp_QP 1  |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
// |  warp_QP 1  |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
__global__ void  // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_kv_kernel(half* Q, half* K, half* V, half* O, ...);
```
- 📚 Split Q (Faster, FlashAttention-2)
```c++
// Split Q across MMA(Warps) and keep access KV for all MMA(Warps),
// in order to reduce the comm between warps via smem and warp shuffle.
// case: MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
// |   64x64   |      warp_KV 0       |
// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
__global__ void  // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_kernel(half* Q, half* K, half* V, half* O, ...);
```
- 📚 Split Q + Shared KV SMEM (1/2 SRAM vs FA2)
```c++
// K and V share the same shared memory, improving block occupancy.
__global__ void  // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, half* K, half* V, half* O, ...);
```
- 📚 Split Q + Fully Shared QKV SMEM (1/4 SRAM vs FA2)
```c++
// Q, K and V fully share the same shared memory, and Q is prefetched s2r, improving block
// occupancy and reducing Q SMEM IO-access.
__global__ void  // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half* O, ...);
```
- 📚 Split Q + QK Fine-grained Tiling (O(16xd) SRAM vs FA2's O(4xBrxd) SRAM, Headdim -> 1024)
```c++
// Fine-grained tiling at the MMA level for Q@K^T results in a constant SRAM usage of
// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
// extend D (the head dimension) up to 1024.
__global__ void  // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
```
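To put that comment into rough numbers (a back-of-the-envelope estimate, assuming kMmaAtomK = 16 and FP16, i.e. 2 bytes per element): the Q and K tiles stay on the order of 64 x 16 x 2 B ≈ 2 KB per stage regardless of d, while an FA2-style O(4 x Br x d) footprint with Br = 64 and d = 1024 would be roughly 4 x 64 x 1024 x 2 B = 512 KB, far more SMEM than a single SM provides. That gap is what makes headdim up to 1024 feasible with QK fine-grained tiling.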
- 📚 Split Q + Fully QKV Fine-grained Tiling (O(2xBrx16)~O(1) SRAM vs FA2's O(4xBrxd) SRAM)
```c++
// Fine-grained tiling at the MMA level for both Q@K^T and P@V results in a constant SRAM usage of
// Br * 16 or Bc * 16 for Q, K and V, leading to an overall SRAM complexity of O(Br * 16). Consequently,
// this approach allows us to run faster than SDPA with or without MMA Acc F32.
__global__ void  // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qkv_kernel(half* Q, half* K, half* V, half* O, ...);
```
💡NOTE: 📚Split Q + Fully QKV Fine-grained Tiling has been refactored into 🤖ffpa-attn.
📖 200+ CUDA Kernels 🔥🔥 (Easy -> Hard++) (©️back👆🏻)
The kernels listed here will guide you through a step-by-step progression, ranging from easy to very challenging topics. The workflow for each topic is as follows: custom CUDA kernel implementation -> PyTorch Python bindings -> run tests. 👉TIPS: * = Tensor Cores (WMMA, MMA, CuTe); otherwise, CUDA Cores; / = not supported; ✔️ = supported; ❔ = TODO. Contents are listed as follows:
📚 Easy and 📚 Medium sections cover operations such as element-wise, mat_trans, warp/block reduce, nms, relu, gelu, swish, layer-norm, rms-norm, online-softmax, dot-prod, embedding and basic usage of FP32, FP16, BF16 and FP8. 📚 Hard, 📚 Hard+ and 📚 Hard++ sections delve into advanced topics, primarily focusing on operations like sgemv, sgemm, hgemv, hgemm and flash-attention. These sections also provide numerous kernels implemented using Tensor Cores with pure MMA PTX.
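For readers new to the "custom CUDA kernel -> PyTorch Python bindings -> run tests" workflow mentioned above, here is a minimal end-to-end sketch (illustrative only; the elementwise_add_f32 kernel and function names are hypothetical, not the repo's actual files):

```c++
// elementwise_add_ext.cu -- minimal CUDA kernel + PyTorch binding sketch.
#include <torch/extension.h>

__global__ void elementwise_add_f32_kernel(const float* a, const float* b,
                                           float* c, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) c[idx] = a[idx] + b[idx];
}

// Host wrapper: checks inputs, allocates the output and launches the kernel.
torch::Tensor elementwise_add_f32(torch::Tensor a, torch::Tensor b) {
  TORCH_CHECK(a.is_cuda() && b.is_cuda(), "inputs must be CUDA tensors");
  TORCH_CHECK(a.scalar_type() == torch::kFloat32, "expected float32 inputs");
  TORCH_CHECK(a.sizes() == b.sizes(), "shape mismatch");
  auto c = torch::empty_like(a);
  const int n = static_cast<int>(a.numel());
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  elementwise_add_f32_kernel<<<blocks, threads>>>(
      a.data_ptr<float>(), b.data_ptr<float>(), c.data_ptr<float>(), n);
  return c;
}

// Python binding: after compiling (e.g. via torch.utils.cpp_extension.load),
// the op is callable as module.elementwise_add_f32(a, b).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("elementwise_add_f32", &elementwise_add_f32, "element-wise add (CUDA)");
}
```

The "run tests" step then typically just compares this op's output against the equivalent PyTorch expression (here, a + b) and times both.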
📚 Easy ⭐️ & Medium ⭐️⭐️ (©️back👆🏻)
📚 Hard ⭐⭐⭐️ (©️back👆🏻)
📚 Hard+ ⭐️⭐️⭐️⭐️ & Hard++ ⭐️⭐️⭐️⭐️⭐️ (©️back👆🏻)
- 📚 FlashAttention-2 MMA (MMA Acc F32/F16, swizzle, QKV smem share, fine-grained tiling, etc.🎉)
💡NOTE: rr means reduced register usage (for d > 128); f32 means the MMA accumulates in FP32 dtype, otherwise FP16 (the softmax Acc dtype is always FP32 for high precision); swizzle: currently, only SMEM swizzle for MMA is supported.
- 📚 FFPA Attention MMA (1.8x~3x🎉 faster vs SDPA EA, D > 256, FA2 not supported)
| 📖 CUDA Kernel | 📖 Elem DType | 📖 Acc DType | 📖 Docs | 📖 Level |
|---|---|---|---|---|
| ✔️ffpa_mma_stages_split_q_L1_F16F16F16 | f16 | f16 | link | ⭐️⭐️⭐️⭐️ |
| ✔️ffpa_mma_stages_split_q_L1_F16F16F32 | f16 | f32 | link | ⭐️⭐️⭐️⭐️ |
| ✔️ffpa_mma_stages_split_q_L1_mixed_acc | f16 | QK f32, PV f16 | link | ⭐️⭐️⭐️⭐️ |
| | f16 | f16 | link | ⭐️⭐️⭐️⭐️ |
| | f16 | f32 | link | ⭐️⭐️⭐️⭐️ |
| | f16 | QK f32, PV f16 | link | ⭐️⭐️⭐️⭐️ |
| | f16 | f16 | link | ⭐️⭐️⭐️⭐️ |
| | f16 | f32 | link | ⭐️⭐️⭐️⭐️ |
| | f16 | QK f32, PV f16 | link | ⭐️⭐️⭐️⭐️ |
💡NOTE: 🤖ffpa-attn: 📚FFPA - Yet another Faster Flash Prefill Attention with O(1)🎉 SRAM complexity for headdim > 256, 1.8x~3x🎉 faster than SDPA EA: 📈L20 ~1.9x↑🎉, 📈A30 ~1.8x↑🎉, 📈3080 ~2.9x↑🎉, 📈4090 ~2.1x↑🎉.
📚 Triton Kernel (OpenAI Triton) ⭐️⭐️⭐️ (©️back👆🏻)
| 📖 Triton Kernel | 📖 Elem DType | 📖 Acc DType | 📖 Docs | 📖 Level |
|---|---|---|---|---|
| ✔️triton_vector_add_kernel | all | all | link | ⭐️⭐️ |
| ✔️triton_fused_softmax(multi-stages) | f16/bf16/f32 | f32 | link | ⭐️⭐️⭐️ |
| ✔️triton_fused_layer_norm(forward-pass) | f16/bf16/f32 | f32 | link | ⭐️⭐️⭐️ |
| ✔️triton_fused_layer_norm(backward-pass) | f16/bf16/f32 | f32 | link | ⭐️⭐️⭐️ |
| ✔️triton_merge_attn_states_kernel(w/ CUDA) | f16/bf16/f32 | f32 | link | ⭐️⭐️⭐️ |
📚 CUTLASS/CuTe Kernel ⭐️⭐️⭐️ (©️back👆🏻)
| 📖 CUTLASS/CuTe Kernel | 📖 Elem DType | 📖 Acc DType | 📖 Docs | 📖 Level |
|---|---|---|---|---|
| ✔️mat_transpose_cute | f32 | / | link | ⭐️⭐️ |
| ✔️flash_attn_cute(naive) | f16 | f32 | link | ⭐️⭐️⭐️ |
| ✔️hgemv_f16_cute_kernel | f16 | f16 | link | ⭐️⭐️⭐️ |
| ✔️hgemv_f16x8_cute_kernel | f16 | f16 | link | ⭐️⭐️⭐️ |
| ✔️hgemv_tensor_core_cute_kernel | f16 | f16 | link | ⭐️⭐️⭐️ |
| ✔️hgemm_mma_stages_swizzle{smem}...cute* | f16 | f16 | link | ⭐️⭐️⭐️ |
| ✔️ws_hgemm_naive_cute_kernel | f16 | f16 | link | ⭐️⭐️⭐️ |
📚 HPC and Distributed Computing - Personal Tech Columns (©️back👆🏻)
📚 HPC and Distributed Computing - Recommended Tech Blogs (©️back👆🏻)
💡Note: This subsection collects some articles I particularly like. PRs recommending more great articles are welcome!
©️License (©️back👆🏻)
GNU General Public License v3.0
🎉Contribute (©️back👆🏻)
How to contribute? Star this repo or check 🌤🌤CONTRIBUTE🎉🎉.
📖 References (©️back👆🏻)