Writing high-performance matrix multiplication kernels for Blackwell#
In this guide, we'll progressively iterate on a matrix multiplication kernel. The first implementation will be very simple, but also quite slow. However, in just a few simple steps it can be modified into a state-of-the-art kernel, matching or exceeding highly optimized implementations such as cuBLAS and CUTLASS.
Warning
The utilization shown in the table below might be different than what you see online, but the differences can likely be explained by a different input data distribution. All our benchmarks here use arrays with iid normal float16 entries, which turn out to be one of the slower distributions you can choose. You can reproduce the numbers for yourself by running our test file after changing the BENCHMARK variable to True.
tl;dr: don't believe matmul benchmarks if they don't specify the input data distribution.
| Implementation | TensorCore utilization | % of cuBLAS utilization |
|---|---|---|
| 0. Basic kernel | 37.62% | 59.4% |
| 1. Warp specialization | 45.47% | 71.7% |
| 2. Tiled epilogue | 55.82% | 88.1% |
| 3. Collective (2CTA) MMA | 59.41% | 93.7% |
| 4. Persistent kernel | 61.46% | 97.0% |
| 5. Dedicated epilogue warpgroup | 63.38% | 100.0% |
| 6. Grid tiling | 69.44% | 109.6% |
| cuBLAS | 63.38% | 100.0% |
| CUTLASS | 69.30% | 109.3% |
The cuBLAS baseline is obtained by measuring the performance of jax.dot. The CUTLASS performance is measured by taking the best result from the following cutlass_profiler invocation (excluding sparse matmuls):
cutlass_profiler \
  --dist=gaussian,mean:0,stddev:1,scale:-1 \
  --output=results.csv \
  --accumulator-type=f32 \
  --m=4096 --k=4096 --n=8192 \
  --kernels='*sm100*' \
  --A=f16 --B=f16 --C=void --D=f16
At each step, we will showcase either the full implementation of the kernel, or the difference between the code listings shown in the previous and current steps. Full implementations can be found in our test file. You can also find a full standalone optimized kernel implementation in the Pallas ops package.
0. Basic kernel#
We begin with a simple single-CTA (block), single-warpgroup example. For convenience, we split the tuning parameters of the kernel into a separate class:
@dataclasses.dataclass(frozen=True)
class TuningConfig:
  tile_m: int
  tile_n: int
  tile_k: int
  max_concurrent_steps: int
tile_m, tile_n and tile_k specify the size of the matmul performed at every step of the pipeline. In general, tile_k should ideally be equal to 128 divided by the bytewidth of the input element type. max_concurrent_steps specifies the depth of memory prefetch in the compute/memory pipeline, which is frequently called the number of stages in other implementations.
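For example (illustrative values, not tuned ones), float16 inputs have a 2-byte element type, so the rule above suggests tile_k = 128 // 2 = 64:

config = TuningConfig(tile_m=128, tile_n=128, tile_k=64, max_concurrent_steps=4)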
The kernel implementation begins with a bit of setup:
def matmul0(a, b, config: TuningConfig):
  dtype = a.dtype
  m, k = a.shape
  _, n = b.shape
  tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k
  swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8)
  swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
  transforms = (
      plgpu.TilingTransform((8, swizzle_elems)),
      plgpu.SwizzleTransform(swizzle),
  )
  if m % tile_m != 0:
    raise ValueError(f"{m=} must be divisible by {tile_m=}")
  if n % tile_n != 0:
    raise ValueError(f"{n=} must be divisible by {tile_n=}")
  if k % tile_k != 0:
    raise ValueError(f"{k=} must be divisible by {tile_k=}")
  m_iters = m // tile_m
  n_iters = n // tile_n
  k_iters = k // tile_k
  max_concurrent_steps = config.max_concurrent_steps
We unpack the config variables for easier access, and set the tiling and swizzling transforms to get the SMEM data format to match what's expected by MMA instructions.
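As a quick sanity check of the setup math (assuming float16 inputs and tile_k=64):

# Illustrative setup math, assuming float16 inputs and tile_k = 64.
itemsize = jnp.dtype(jnp.float16).itemsize       # 2 bytes
swizzle = plgpu.find_swizzle(64 * itemsize * 8)  # 1024-bit rows; we expect this
                                                 # to pick the 128-byte swizzle.
swizzle_elems = swizzle // itemsize              # 128 // 2 == 64 elements.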
The kernel implementation itself is relatively short. The first part sets up a compute/memory pipeline using plgpu.emit_pipeline. At each step, the compute function (do_mma) consumes a (tile_m, tile_k) slice of the LHS and a (tile_k, tile_n) slice of the RHS. As mentioned before, we specify transforms, as well as delay_release=1. This last parameter ensures that the input windows (a_smem, b_smem) passed into do_mma will not be overwritten at least until the next invocation of do_mma completes. This is necessary because we only await the completion of the MMA from one step in the following step, which is why arrive_barrier_slot and wait_barrier_slot flip between 0 and 1 at each invocation.
def kernel(a_gmem, b_gmem, out_gmem, acc_tmem, acc_smem, consumed_barriers):
  mi = lax.axis_index("m")
  ni = lax.axis_index("n")
  m_slice = pl.ds(mi * tile_m, tile_m)
  n_slice = pl.ds(ni * tile_n, tile_n)

  def do_mma(idxs, a_smem, b_smem):
    (ki,) = idxs
    arrive_barrier_slot = ki % 2
    wait_barrier_slot = 1 - arrive_barrier_slot
    plgpu.tcgen05_mma(
        acc_tmem,
        a_smem,
        b_smem,
        barrier=consumed_barriers.at[arrive_barrier_slot],
        accumulate=(ki > 0),
    )
    plgpu.barrier_wait(consumed_barriers.at[wait_barrier_slot])

  # Make sure the wait succeeds in the first iteration.
  plgpu.barrier_arrive(consumed_barriers.at[1])
  block_kwargs = dict(transforms=transforms, delay_release=1)
  plgpu.emit_pipeline(
      do_mma,
      in_specs=[
          plgpu.BlockSpec((tile_m, tile_k), lambda ki: (mi, ki), **block_kwargs),
          plgpu.BlockSpec((tile_k, tile_n), lambda ki: (ki, ni), **block_kwargs),
      ],
      grid=(k_iters,),
      max_concurrent_steps=max_concurrent_steps,
  )(a_gmem, b_gmem)
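To see why the parity trick works, it can help to trace the barrier slots for a short pipeline (say k_iters = 3):

# Trace of the consumed_barriers slots for k_iters = 3:
#   before the loop: barrier_arrive on slot 1       (pre-arm)
#   ki = 0: MMA arrives on slot 0; wait on slot 1   (succeeds via the pre-arm)
#   ki = 1: MMA arrives on slot 1; wait on slot 0   (awaits ki=0's MMA)
#   ki = 2: MMA arrives on slot 0; wait on slot 1   (awaits ki=1's MMA)
# The epilogue below then waits on slot 1 - (3 % 2) == 0 for ki=2's MMA.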
The kernel itself ends with an epilogue. We await the completion of the last MMA issued by the pipeline before doing anything. Then, we load the final accumulator from TMEM, write it to SMEM (remembering to call plgpu.commit_smem), and copy it back to GMEM using TMA.
def kernel(...):
  ...  # compute pipeline as above
  final_barrier = 1 - (k_iters % 2)
  plgpu.barrier_wait(consumed_barriers.at[final_barrier])
  acc_smem[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype)
  plgpu.commit_smem()
  plgpu.copy_smem_to_gmem(acc_smem, out_gmem.at[m_slice, n_slice])
  plgpu.wait_smem_to_gmem(0, wait_read_only=True)
What remains is to actually turn the kernel into a function that can be called with JAX arrays. We use plgpu.kernel for that. The grid is for now simply 2D and iterates over the output tiles. We allocate the intermediates used by the kernel:
- The TMEM buffer used as an accumulator
- The SMEM buffer used to stage the accumulator before its copy to GMEM
- The barrier used to await the completion of MMA operations
def matmul0(a, b, config):
  ...  # Setup code from the first snippet
  def kernel(...):
    ...  # The whole kernel body
  f = plgpu.kernel(
      kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), dtype),
      grid=(m_iters, n_iters),
      grid_names=("m", "n"),
      scratch_shapes=dict(
          acc_tmem=plgpu.TMEM((tile_m, tile_n), jnp.float32),
          acc_smem=plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms),
          consumed_barriers=plgpu.Barrier(
              num_arrivals=1, num_barriers=2, orders_tensor_core=True
          ),
      ),
  )
  return f(a, b)
Omitting the setup code, that's just 50 lines! Unfortunately, it's not very fast just yet, but it does achieve half the utilization of cuBLAS already!
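As a sanity check, the kernel can be called like any other JAX function. The shapes below match the benchmark configuration from the top of this guide; the tile sizes are illustrative assumptions, not tuned values:

a = jax.random.normal(jax.random.key(0), (4096, 4096), jnp.float16)
b = jax.random.normal(jax.random.key(1), (4096, 8192), jnp.float16)
config = TuningConfig(tile_m=128, tile_n=128, tile_k=64, max_concurrent_steps=4)
out = matmul0(a, b, config)  # (4096, 8192) float16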
1. Warp specialization#
Note
Recall that on Blackwell a single Pallas:MGPU thread of execution corresponds to a warpgroup of CUDA lanes/threads.
The kernel above uses a single warpgroup to do everything: from fetching the data, through issuing MMA operations, to storing the results into GMEM. While one might expect the asynchrony of TensorCore execution to let us overlap the overheads of async copies (TMA) and control flow, this does not seem to be the case in practice.
A common solution to this problem in the Hopper generation of GPUs was to utilize warpgroup specialization. In Pallas terms, plgpu.kernel can be called with num_threads=2, meaning that each program in the grid would result in two calls to the body. The thread index is then often queried using lax.axis_index and used to select one of multiple different roles, such as only issuing async copies or only running the MMA operations.
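For contrast, a minimal sketch of that Hopper-style approach might look as follows (the role split shown here is illustrative; it is not the structure we adopt below):

def kernel(...):
  wg_idx = lax.axis_index("wg")  # With num_threads=2, this is 0 or 1.

  @pl.when(wg_idx == 0)
  def _memory():
    ...  # Only issue the async copies.

  @pl.when(wg_idx == 1)
  def _compute():
    ...  # Only issue the MMA operations.

f = plgpu.kernel(kernel, ..., num_threads=2, thread_name="wg")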
This solution also works in the Blackwell generation, but it is in fact even simpler there. Since both the async copy (TMA) and the tcgen05 MMA instruction only require a single CUDA lane to issue them, we don't even need to use multiple warpgroups. We can simply break up a single warpgroup into four warps and specialize those!
In Pallas, this can be achieved using pl.core_map with a plgpu.WarpMesh. For each Pallas thread that calls such a core_map, the body will be invoked exactly four times. The core_map synchronizes all warps both at entry and at exit. Note that only scalar operations are allowed in the body.
This will be the biggest rewrite of the kernel in this whole sequence, which is why we list the entire kernel source once again.
def matmul1(a, b, config: TuningConfig):
  ...  # Setup code remains unmodified
  def kernel(a_gmem, b_gmem, out_gmem, a_smem, b_smem, acc_tmem, acc_smem,
             load_barriers, consumed_barriers, mma_done_barrier):
    m_index = lax.axis_index("m")
    n_index = lax.axis_index("n")
    m_slice = pl.ds(m_index * tile_m, tile_m)
    n_slice = pl.ds(n_index * tile_n, tile_n)

    @pl.core_map(plgpu.WarpMesh(axis_name="warp"))
    def _per_warp():
      warp_id = lax.axis_index("warp")

      @pl.when(warp_id == 0)
      def _memory():
        def _loop_body(ki, _):
          slot = lax.rem(ki, max_concurrent_steps)
          @pl.when(ki >= max_concurrent_steps)
          def _():
            # Make sure the data has been consumed before overwriting.
            plgpu.barrier_wait(consumed_barriers.at[slot])
          k_slice = pl.ds(ki * tile_k, tile_k)
          plgpu.copy_gmem_to_smem(
              a_gmem.at[m_slice, k_slice], a_smem.at[slot], load_barriers.at[slot]
          )
          plgpu.copy_gmem_to_smem(
              b_gmem.at[k_slice, n_slice], b_smem.at[slot], load_barriers.at[slot]
          )
        lax.fori_loop(0, k_iters, _loop_body, None)

      @pl.when(warp_id == 1)
      def _compute():
        def _loop_body(ki, _):
          slot = lax.rem(ki, max_concurrent_steps)
          plgpu.barrier_wait(load_barriers.at[slot])  # Wait for data to arrive.
          plgpu.tcgen05_mma(
              acc_tmem,
              a_smem.at[slot],
              b_smem.at[slot],
              consumed_barriers.at[slot],
              accumulate=(ki > 0),
          )
        lax.fori_loop(0, k_iters, _loop_body, None)
        plgpu.tcgen05_commit_arrive(mma_done_barrier)

    plgpu.barrier_wait(mma_done_barrier)
    acc_smem[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype)
    plgpu.commit_smem()
    plgpu.copy_smem_to_gmem(acc_smem, out_gmem.at[m_slice, n_slice])
    plgpu.wait_smem_to_gmem(0, wait_read_only=True)
The kernel has exactly the same structure as before: we first perform the compute, which is followed by the epilogue. The epilogue remains the same (we only use a different barrier to await the completion), so we will not discuss it further.
The plgpu.emit_pipeline call and the do_mma function have been replaced by a single pl.core_map invocation. You can see that immediately after entering its body, each Pallas thread (now representing a warp!) finds out which of the four threads it is. The thread with index 0 then only issues the async copies that fetch the MMA operands in a loop, while the thread with index 1 enters another loop in which it repeatedly calls plgpu.tcgen05_mma.
One interesting aspect here is the synchronization. We keep an array of load_barriers, each tracking the progress of an outstanding GMEM->SMEM copy. The compute thread must await their completion before feeding the respective operands to the MMA operation. Going in the other direction, the thread responsible for async copies must await the completion of the MMAs that consume the operands before it can overwrite the memory by issuing another async copy. This is tracked through consumed_barriers. Finally, when the compute thread is done issuing all MMA operations, it calls plgpu.tcgen05_commit_arrive(mma_done_barrier), requesting the TensorCore to complete the mma_done_barrier once all the MMA operations complete.
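In summary, the barriers connect the warps as follows (the first two have one slot per in-flight pipeline stage):

#   load_barriers[slot]      TMA copy completes -> compute warp may issue MMA
#   consumed_barriers[slot]  MMA consumed SMEM  -> memory warp may reuse slot
#   mma_done_barrier         all MMAs complete  -> epilogue may read TMEM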
We can now turn our attention to the plgpu.kernel definition. The only difference from the prior version is that we explicitly allocate two additional SMEM buffers that hold the MMA operands (previously they were implicitly allocated by plgpu.emit_pipeline), as well as the additional barriers. Note that the load_barriers have num_arrivals=2, because we issue two async copies on the same barrier. orders_tensor_core must be specified on barriers that are meant to indicate the completion of TensorCore operations.
def matmul1(a, b, config: TuningConfig):
  ...  # Setup code remains unmodified
  def kernel(...):
    ...  # Kernel code above
  f = plgpu.kernel(
      kernel,
      ...,  # Other parameters remain unchanged
      scratch_shapes=dict(
          a_smem=plgpu.SMEM(
              (max_concurrent_steps, tile_m, tile_k), dtype, transforms=transforms
          ),
          b_smem=plgpu.SMEM(
              (max_concurrent_steps, tile_k, tile_n), dtype, transforms=transforms
          ),
          acc_tmem=plgpu.TMEM((tile_m, tile_n), jnp.float32),
          acc_smem=plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms),
          load_barriers=plgpu.Barrier(
              num_arrivals=2, num_barriers=max_concurrent_steps
          ),
          consumed_barriers=plgpu.Barrier(
              num_arrivals=1,
              num_barriers=max_concurrent_steps,
              orders_tensor_core=True,
          ),
          mma_done_barrier=plgpu.Barrier(
              num_arrivals=1, num_barriers=1, orders_tensor_core=True
          ),
      ),
  )
  return f(a, b)
This relatively simple modification already gives us a meaningful bump in performance, getting us up to almost 70% of cuBLAS performance.
2. Tiled epilogue#
This time, we turn our attention away from the compute portion of the kernel and instead focus on its epilogue. We can improve its efficiency by pipelining the copy from TMEM to SMEM together with the transfer from SMEM to GMEM. To do this, we change our scratch_shapes to allocate two smaller buffers instead of an SMEM window that can hold the entire output, which also decreases our SMEM usage (this assumes TuningConfig has been extended with an epilogue_tile_n field):
def matmul2(a, b, config):
  ...  # Setup and kernel code
  f = plgpu.kernel(
      ...,
      scratch_shapes=dict(
          ...,
          # Previously: plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms),
          acc_smem=plgpu.SMEM(
              (2, tile_m, config.epilogue_tile_n), dtype, transforms=transforms
          ),
          ...,
      ),
  )
Then, in the kernel, we loop over the output columns in chunks of epilogue_tile_n, and progressively send out the output to GMEM:
def matmul2(a, b, config):
  ...  # Setup code remains unchanged
  def kernel(...):
    ...  # Compute part remains unchanged
    plgpu.barrier_wait(mma_done_barrier)
    out_gmem_window = out_gmem.at[m_slice, n_slice]
    for ni in range(tile_n // config.epilogue_tile_n):
      acc_smem_ni = acc_smem.at[ni % 2]
      ni_slice = pl.ds(ni * config.epilogue_tile_n, config.epilogue_tile_n)
      # Make sure that previous copy is done before we overwrite.
      plgpu.wait_smem_to_gmem(1, wait_read_only=True)
      acc_smem_ni[...] = plgpu.async_load_tmem(acc_tmem.at[:, ni_slice]).astype(dtype)
      plgpu.commit_smem()
      plgpu.copy_smem_to_gmem(acc_smem_ni, out_gmem_window.at[:, ni_slice])
    plgpu.wait_smem_to_gmem(0, wait_read_only=True)
3. Collective (2CTA) MMA#
If you benchmark our latest kernel, you'll quickly find out that it can't utilize its compute units well, because they are constantly waiting on the memory to deliver the MMA operands. This means that our kernel is memory bound: its arithmetic intensity, the number of flops we perform for each byte we load, is too low.
One very effective trick of the Blackwell architecture that allows us to double our arithmetic intensity is collective MMAs. The core idea is quite simple: we use a cluster of two blocks (on two SMs) to compute a single matmul. Each block only loads half of each operand, but the MMA operation exchanges the data from the SMEM of each block as it's running.
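A back-of-the-envelope calculation shows where the doubling comes from (the tile sizes here are illustrative). Per pipeline step, a single block loads (tile_m + tile_n) * tile_k elements, and the MMA performs 2 * tile_m * tile_n * tile_k flops, so with float16 and tile_m = tile_n = 128:

  intensity = (2 * 128 * 128 * tile_k) / ((128 + 128) * tile_k * 2 bytes)
            = 64 flops/byte

In the collective scheme, the cluster computes a (2 * tile_m, 2 * tile_n) output tile while each block still loads only (tile_m + tile_n) * tile_k elements, so the per-block intensity doubles to 128 flops/byte.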
We’ll start with the kernel configuration changes again:
def matmul3(a, b, config):
  ...  # Setup code
  cluster_tile_m = 2 * tile_m
  cluster_tile_n = 2 * tile_n
  m_iters = m // cluster_tile_m
  n_iters = n // cluster_tile_n
  ...  # Setup code and kernel
  f = plgpu.kernel(
      ...,
      grid=(m_iters, n_iters),
      ...,
      cluster=(2,),
      cluster_names=("cluster",),
      scratch_shapes=dict(
          ...,
          # Previously: plgpu.TMEM((tile_m, tile_n), jnp.float32),
          acc_tmem=plgpu.TMEM((tile_m, cluster_tile_n), jnp.float32, collective=True),
          ...,
      ),
  )
We add the cluster parameter to plgpu.kernel to indicate that we intend to have pairs of programs collaborate (as CUDA block clusters). We also append collective=True to our TMEM allocation, to ensure that it will be allowed to be used by collective MMAs, and double its number of columns (to cluster_tile_n).
Another notable change is that our pair of blocks will ultimately compute a 4x larger output tile, which is why we shrink the grid correspondingly.
We first update the entry of the kernel:
def kernel(...):
  is_lead_block = lax.axis_index("cluster") == 0
  m_index = lax.axis_index("m")
  n_index = lax.axis_index("n")
  m_slice = pl.ds(m_index * cluster_tile_m, cluster_tile_m)
  n_slice = pl.ds(n_index * cluster_tile_n, cluster_tile_n)
The only changes here are that we use cluster_tile_m and cluster_tile_n to compute the slice of the output the two blocks will collectively compute, and we also check if the current invocation corresponds to the first (leader) block in the cluster. This is important, because only the leader block is supposed to issue MMA instructions:
@pl.core_map(plgpu.WarpMesh(axis_name="warp"))
def _per_warp():
  warp_id = lax.axis_index("warp")

  @pl.when(warp_id == 0)
  def _memory():
    def _loop_body(ki, _):
      ...  # Wait for the data to be consumed, as previously.
      plgpu.copy_gmem_to_smem(..., collective_axes="cluster", partitioned_axis=0)
      plgpu.copy_gmem_to_smem(..., collective_axes="cluster", partitioned_axis=1)
    lax.fori_loop(0, k_iters, _loop_body, None)

  @pl.when(jnp.logical_and(warp_id == 1, is_lead_block))
  def _compute():
    def _loop_body(ki, _):
      ...  # Wait for the data to arrive, as previously.
      plgpu.tcgen05_mma(..., collective_axis="cluster")
    lax.fori_loop(0, k_iters, _loop_body, None)
    plgpu.tcgen05_commit_arrive(mma_done_barrier, collective_axis="cluster")
You can see a few modifications here. First of all, both blocks must issue the async copies. In both blocks we request a copy of the full window for the whole cluster, but the addition of collective_axes="cluster" indicates that the load is performed jointly by both blocks. partitioned_axis= specifies which axis of the operand should be split across the cluster. We split the LHS rows and the RHS columns.
Warning
A partitioned collective copy only completes the barrier passed in to copy_gmem_to_smem in the leader block of the cluster! This is why you will see that the kernel never awaits the loads in the second block.
Secondly, as mentioned before, we additionally predicate the _compute body so that only the leader block runs MMA instructions. All tcgen05 calls additionally get a collective_axis= argument, to indicate that the completion of MMAs should complete the barriers in both blocks in the cluster.
Finally, we apply a small modification to the epilogue. Even though the two blocks in the cluster collectively compute a result of shape (cluster_tile_m, cluster_tile_n), each individual block only holds a result of shape (tile_m, cluster_tile_n). We change the output slicing code to slice out the right out_gmem_window:
def matmul3(a, b, config):
  ...
  def kernel(...):
    ...  # Compute
    plgpu.barrier_wait(mma_done_barrier)
    out_m_index = m_index * 2 + lax.axis_index("cluster")
    out_m_slice = pl.ds(out_m_index * tile_m, tile_m)
    out_gmem_window = out_gmem.at[out_m_slice, n_slice]
    for ni in range(cluster_tile_n // config.epilogue_tile_n):
      ...
  ...
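For example (illustrative numbers): with tile_m = 128 and m_index = 3, the leader block gets out_m_index = 6 and writes output rows 768 through 895, while the other block gets out_m_index = 7 and writes rows 896 through 1023, together covering the cluster's 256-row slice.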
4. Persistent kernel#
Our next step is to make the kernel persistent. This means that we'll only launch however many clusters we can actually run concurrently on the GPU (the SM count divided by 2), and we'll have each cluster loop over a fixed number of output tiles. This technique allows us to better amortize block (de)initialization costs (since they are only performed once on each SM) and achieve a small degree of overlap between the SMEM to GMEM copy in the epilogue and the compute on the next output tile.
def matmul4(a, b, config):
  ...
  num_sms = jax.extend.backend.get_default_device().core_count
  f = plgpu.kernel(
      ...,
      grid=(num_sms // 2,),
      grid_names=("cluster_grid",),
      ...,
  )
The change is relatively simple. We utilize the plgpu.nd_loop helper to specify that our iteration space is (m_iters, n_iters), but we also request that it should be split across the cluster grid using the collective_axes= argument.
def matmul4(a, b, config):
  ...
  def kernel(...):
    is_lead_block = lax.axis_index("cluster") == 0

    @plgpu.nd_loop((m_iters, n_iters), collective_axes="cluster_grid")
    def _mn_loop(loop_info: plgpu.NDLoopInfo):
      m_index, n_index = loop_info.index
      m_slice = ...
      n_slice = ...
      ...  # Compute + epilogue
The only meaningful modification in the compute portion of the kernel body is to ensure that the first few waits on consumed_barriers in the memory warp are only skipped when processing the first output tile (as indicated by loop_info.local_index == 0). When processing the second (or later) tile, the SMEM buffers were used to compute the previous output tile, so we need to ensure that those computations have completed before we overwrite them:
def matmul4(a, b, config):
  ...
  def kernel(...):
    ...
    def _mn_loop(...):
      ...
      @pl.core_map(plgpu.WarpMesh(axis_name="warp"))
      def _per_warp():
        warp_id = lax.axis_index("warp")

        @pl.when(warp_id == 0)
        def _memory():
          def _loop_body(ki, _):
            slot = lax.rem(ki, max_concurrent_steps)
            @pl.when(
                jnp.logical_or(ki >= max_concurrent_steps, loop_info.local_index > 0)
            )
            def _():
              # Make sure the data has been consumed before overwriting.
              plgpu.barrier_wait(consumed_barriers.at[slot])
Finally, we modify the kernel epilogue by appending a single line:
def matmul4(a, b, config):
  ...
  def kernel(...):
    ...
    def _mn_loop(...):
      ...  # Compute + epilogue
      plgpu.wait_load_tmem()  # Load must complete before MMA can overwrite TMEM.
As the comment indicates, since TMEM loads are asynchronous, we must await their completion before we move on to the next output tile and overwrite our TMEM allocation by issuing another MMA.
5. Dedicated epilogue warpgroup#
While persistence was useful by itself, it also unlocks another optimization. When the single Pallas thread in our kernel finishes the compute portion, it performs the entire epilogue. However, this means that it can't issue any more work for the TensorCore until it's done!
This leads us to a simple solution: just use 2 Pallas threads (warpgroups)! The first one will only focus on fetching the MMA operands and issuing the MMA operations, while the second one will only perform the epilogue! Of course, to enable them to run concurrently, we need to double-buffer the TMEM used for the accumulator, and use additional barriers to synchronize:
def matmul5(a, b, config):
  ...
  f = plgpu.kernel(
      ...,
      num_threads=2,
      thread_name="wg",
      scratch_shapes=dict(
          ...,
          # Previously: plgpu.TMEM((tile_m, cluster_tile_n), jnp.float32, collective=True),
          acc_tmem=plgpu.TMEM(
              (tile_m, 2 * cluster_tile_n), jnp.float32, collective=True
          ),
          ...,
          # mma_done_barrier (now 2 barriers) + a new store_done_barrier (also 2 barriers)
          # Previously: plgpu.Barrier(num_arrivals=1, num_barriers=1, orders_tensor_core=True),
          mma_done_barrier=plgpu.Barrier(
              num_arrivals=1, num_barriers=2, orders_tensor_core=True
          ),
          store_done_barrier=plgpu.ClusterBarrier(
              collective_axes=("cluster",),
              num_arrivals=1,
              num_barriers=2,
              orders_tensor_core=True,
          ),
      ),
  )
The kernel begins similarly to what we had before. We rename acc_tmem to acc_tmem_slots and switch between its halves as we step through the loop over the output tiles:
def matmul(a, b, config):
  ...
  def kernel(a_gmem, b_gmem, out_gmem, a_smem, b_smem, acc_tmem_slots, acc_smem,
             load_barriers, consumed_barriers, mma_done_barrier, store_done_barrier):
    wg_idx = lax.axis_index("wg")
    is_lead_block = ...

    @plgpu.nd_loop(...)
    def _mn_loop(...):
      ...
      acc_slot = lax.rem(loop_info.local_index, jnp.int32(2))
      acc_tmem = acc_tmem_slots.at[:, pl.ds(acc_slot * cluster_tile_n, cluster_tile_n)]
      ...
The compute portion is additionally predicated on wg_idx == 0. There are also two important changes to how we use the barriers. First of all, if we want to reuse our TMEM allocation for an MMA (which happens only for loop_info.local_index >= 2), we need to wait on the store_done_barrier for the TMEM half we want to reuse (as indicated by acc_slot). Secondly, when we request the TensorCore to arrive on the mma_done_barrier upon completion, we again need to select the one of the two barriers that corresponds to the currently used half of TMEM.
Warning
Note that even though only one of the blocks in the cluster issues MMAs, they both await the store_done_barrier. This is only necessary because arriving on the same barrier twice without a wait in between sometimes leads to hardware assertions.
def matmul(a, b, config):
  ...
  def kernel(...):
    ...
    def _mn_loop(...):
      acc_slot = ...
      acc_tmem = ...

      @pl.when(wg_idx == 0)
      def _compute_wg():
        @pl.core_map(plgpu.WarpMesh(axis_name="warp"))
        def _per_warp():
          warp_id = lax.axis_index("warp")

          @pl.when(warp_id == 0)
          def _memory():
            ...  # Memory code remains unchanged

          # Wait for store to complete (except for the first two steps).
          @pl.when(jnp.logical_and(warp_id == 1, loop_info.local_index >= 2))
          def _wait_store():
            plgpu.barrier_wait(store_done_barrier.at[acc_slot])

          @pl.when(jnp.logical_and(warp_id == 1, is_lead_block))
          def _compute():
            ...  # Compute loop remains unchanged
            plgpu.tcgen05_commit_arrive(
                mma_done_barrier.at[acc_slot], collective_axis="cluster"
            )
Finally, we modify the epilogue by only having the second warpgroup execute it, and by making that warpgroup signal the completion of the store by arriving on the store_done_barrier associated with the half of TMEM it used.
def matmul(a, b, config):
  ...
  def kernel(...):
    ...
    def _mn_loop(...):
      ...  # Compute

      @pl.when(wg_idx == 1)
      def _store_wg():
        ...  # Unmodified epilogue
        plgpu.wait_load_tmem()  # Load must complete before we signal.
        plgpu.barrier_arrive(store_done_barrier.at[acc_slot])
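In steady state, the two warpgroups now ping-pong between the TMEM halves, roughly like this (a schematic, not output of the kernel):

#   output tile i:     compute WG fills TMEM slot i % 2
#   output tile i - 1: store WG drains TMEM slot (i - 1) % 2 concurrently
#   store_done_barrier.at[s] tells the compute WG when slot s may be reused.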
6. Grid tiling#
Our final change to this kernel is to adjust the order in which we produce the output blocks to better utilize L2. As mentioned before, the compute units are extremely fast compared to the memory system, so we should use all the help we can get to keep them busy.
Note
This trick goes by many different names. CUTLASS calls it “rasterization order”, ThunderKittens calls it “supergrouping”, while the Triton tutorials call it “program re-ordering”. We use the name “grid tiling”.
Our strategy for this is inspired by CUTLASS and works as follows. First, you select which of the two dimensions in your iteration space is the faster changing one (we call it grid_minor_dim). Then, you select the tile size along that dimension (grid_tile_width). Instead of traversing the whole minor dimension of the grid before incrementing the more major index, we do so every time we traverse grid_tile_width elements. Once we run out of elements, we move on to the next tile. But there's a twist! Instead of jumping to the beginning of the second tile, we start from its end and work our way back. This ensures that as we switch tiles, we can reuse some of the recent blocks of one of the operands.
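For example, for a (m_iters=4, n_iters=8) iteration space with grid_minor_dim=1 and grid_tile_width=4, the traversal looks roughly like this (the numbers show the visit order; the exact order within each row is an implementation detail of the helper):

#   tile 0 (n in 0..3)         tile 1 (n in 4..7)
#   m=0:   1   2   3   4       m=3:  17  18  19  20
#   m=1:   5   6   7   8       m=2:  21  22  23  24
#   m=2:   9  10  11  12       m=1:  25  26  27  28
#   m=3:  13  14  15  16       m=0:  29  30  31  32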
Since this strategy is so common, we provide a helper for it: plgpu.planar_snake. When using the helper, the changes to the kernel are quite trivial:
def matmul(a, b, config):
  ...
  def kernel(...):
    ...
    # We now only iterate over a 1D loop (but we still split it across clusters).
    @plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid")
    def _mn_loop(loop_info: plgpu.NDLoopInfo):
      (lin_idx,) = loop_info.index
      m_index, n_index = plgpu.planar_snake(
          lin_idx,  # Linear index.
          (m_iters, n_iters),  # The 2D iteration space.
          config.grid_minor_dim,  # 0 or 1, indicates the fastest changing dim.
          config.grid_tile_width,  # The width of tiles along the fastest changing dim.
      )
      ...  # Rest of the code remains unmodified
This simple trick is incredibly effective and is crucial in achieving state-of-the-art performance.
Final kernel#
You've reached the end of this tutorial, congratulations! In the previous sections, we focused only on the differences between consecutive kernels and rarely listed the complete source. This is useful for hiding irrelevant details when extending the implementation, but it can also be helpful to see the full source. So here it is! The whole implementation is less than 150 lines and reaches SOTA performance (at least on the shape used in our benchmarks).
def matmul6(a, b, config: TuningConfig):
  dtype = a.dtype
  m, k = a.shape
  _, n = b.shape
  tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k
  swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8)
  swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
  transforms = (
      plgpu.TilingTransform((8, swizzle_elems)),
      plgpu.SwizzleTransform(swizzle),
  )
  if m % tile_m != 0:
    raise ValueError(f"{m=} must be divisible by {tile_m=}")
  if n % tile_n != 0:
    raise ValueError(f"{n=} must be divisible by {tile_n=}")
  if k % tile_k != 0:
    raise ValueError(f"{k=} must be divisible by {tile_k=}")
  cluster_tile_m = 2 * tile_m
  cluster_tile_n = 2 * tile_n
  m_iters = m // cluster_tile_m
  n_iters = n // cluster_tile_n
  k_iters = k // tile_k
  max_concurrent_steps = config.max_concurrent_steps

  def kernel(a_gmem, b_gmem, out_gmem, a_smem, b_smem, acc_tmem, acc_smem,
             load_barriers, consumed_barriers, mma_done_barrier, store_done_barrier):
    wg_idx = lax.axis_index("wg")
    is_lead_block = lax.axis_index("cluster") == 0

    @plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid")
    def _mn_loop(loop_info: plgpu.NDLoopInfo):
      (lin_idx,) = loop_info.index
      m_index, n_index = plgpu.planar_snake(
          lin_idx,
          (m_iters, n_iters),
          config.grid_minor_dim,
          config.grid_tile_width,
      )
      m_slice = pl.ds(m_index * cluster_tile_m, cluster_tile_m)
      n_slice = pl.ds(n_index * cluster_tile_n, cluster_tile_n)
      acc_slot = lax.rem(loop_info.local_index, jnp.int32(2))
      mn_acc_tmem = acc_tmem.at[:, pl.ds(acc_slot * cluster_tile_n, cluster_tile_n)]

      @pl.when(wg_idx == 0)
      def _compute_wg():
        @pl.core_map(plgpu.WarpMesh(axis_name="warp"))
        def _per_warp():
          warp_id = lax.axis_index("warp")

          @pl.when(warp_id == 0)
          def _memory():
            def _loop_body(ki, _):
              slot = lax.rem(ki, max_concurrent_steps)
              @pl.when(
                  jnp.logical_or(ki >= max_concurrent_steps, loop_info.local_index > 0)
              )
              def _():
                # Make sure the data has been consumed before overwriting.
                plgpu.barrier_wait(consumed_barriers.at[slot])
              k_slice = pl.ds(ki * tile_k, tile_k)
              plgpu.copy_gmem_to_smem(
                  a_gmem.at[m_slice, k_slice], a_smem.at[slot],
                  load_barriers.at[slot],
                  collective_axes="cluster", partitioned_axis=0,
              )
              plgpu.copy_gmem_to_smem(
                  b_gmem.at[k_slice, n_slice], b_smem.at[slot],
                  load_barriers.at[slot],
                  collective_axes="cluster", partitioned_axis=1,
              )
            lax.fori_loop(0, k_iters, _loop_body, None)

          # Wait for store to complete (except for the first two steps).
          @pl.when(jnp.logical_and(warp_id == 1, loop_info.local_index >= 2))
          def _wait_store():
            plgpu.barrier_wait(store_done_barrier.at[acc_slot])

          @pl.when(jnp.logical_and(warp_id == 1, is_lead_block))
          def _compute():
            def _loop_body(ki, _):
              slot = lax.rem(ki, max_concurrent_steps)
              plgpu.barrier_wait(load_barriers.at[slot])  # Wait for data to arrive.
              plgpu.tcgen05_mma(
                  mn_acc_tmem,
                  a_smem.at[slot],
                  b_smem.at[slot],
                  consumed_barriers.at[slot],
                  accumulate=(ki > 0),
                  collective_axis="cluster",
              )
            lax.fori_loop(0, k_iters, _loop_body, None)
            plgpu.tcgen05_commit_arrive(
                mma_done_barrier.at[acc_slot],
                collective_axis="cluster",
            )

      @pl.when(wg_idx == 1)
      def _store_wg():
        # Ensure that copies from the previous mn step have completed.
        plgpu.wait_smem_to_gmem(0, wait_read_only=True)
        plgpu.barrier_wait(mma_done_barrier.at[acc_slot])
        out_m_index = m_index * 2 + lax.axis_index("cluster")
        out_m_slice = pl.ds(out_m_index * tile_m, tile_m)
        out_gmem_window = out_gmem.at[out_m_slice, n_slice]
        for ni in range(cluster_tile_n // config.epilogue_tile_n):
          acc_smem_ni = acc_smem.at[ni % 2]
          ni_slice = pl.ds(ni * config.epilogue_tile_n, config.epilogue_tile_n)
          # Make sure that previous copy is done before we overwrite.
          plgpu.wait_smem_to_gmem(1, wait_read_only=True)
          acc_smem_ni[...] = plgpu.async_load_tmem(
              mn_acc_tmem.at[:, ni_slice]
          ).astype(dtype)
          plgpu.commit_smem()
          plgpu.copy_smem_to_gmem(acc_smem_ni, out_gmem_window.at[:, ni_slice])
        plgpu.wait_load_tmem()  # Load must complete before we signal.
        plgpu.barrier_arrive(store_done_barrier.at[acc_slot])

    plgpu.wait_smem_to_gmem(0, wait_read_only=True)

  num_sms = backend.get_default_device().core_count
  f = plgpu.kernel(
      kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), dtype),
      grid=(num_sms // 2,),
      grid_names=("cluster_grid",),
      cluster=(2,),
      cluster_names=("cluster",),
      num_threads=2,
      thread_name="wg",
      scratch_shapes=dict(
          a_smem=plgpu.SMEM(
              (max_concurrent_steps, tile_m, tile_k), dtype, transforms=transforms
          ),
          b_smem=plgpu.SMEM(
              (max_concurrent_steps, tile_k, tile_n), dtype, transforms=transforms
          ),
          acc_tmem=plgpu.TMEM((tile_m, 2 * cluster_tile_n), jnp.float32, collective=True),
          acc_smem=plgpu.SMEM(
              (2, tile_m, config.epilogue_tile_n), dtype, transforms=transforms
          ),
          load_barriers=plgpu.Barrier(num_arrivals=2, num_barriers=max_concurrent_steps),
          consumed_barriers=plgpu.Barrier(
              num_arrivals=1,
              num_barriers=max_concurrent_steps,
              orders_tensor_core=True,
          ),
          mma_done_barrier=plgpu.Barrier(
              num_arrivals=1, num_barriers=2, orders_tensor_core=True
          ),
          store_done_barrier=plgpu.ClusterBarrier(
              collective_axes=("cluster",),
              num_arrivals=1,
              num_barriers=2,
              orders_tensor_core=True,
          ),
      ),
  )
  return f(a, b)
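For reference, here is a plausible invocation. The tile sizes are illustrative assumptions rather than tuned values, and we assume TuningConfig has been extended with the epilogue_tile_n, grid_minor_dim, and grid_tile_width fields used in the previous steps:

config = TuningConfig(
    tile_m=128, tile_n=128, tile_k=64, max_concurrent_steps=4,
    epilogue_tile_n=64, grid_minor_dim=0, grid_tile_width=8,
)
a = jax.random.normal(jax.random.key(0), (4096, 4096), jnp.float16)
b = jax.random.normal(jax.random.key(1), (4096, 8192), jnp.float16)
out = matmul6(a, b, config)  # Same shapes as in the benchmarks above.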
