Writing Mosaic GPU kernels with Pallas
This page is a reference for the most important features of the Pallas:MGPU backend. It’s not a tutorial and as such we do not expect everyone to read it top to bottom. Still, it is worth going over just to familiarise yourself with some patterns you can find in other tutorials.
In the following examples, we’re going to assume the following imports are in scope:
    import jax.experimental.pallas as pl
    import jax.experimental.pallas.mosaic_gpu as plgpu
What is a GPU?#
Technically, the NVIDIA GPU architecture looks as follows: the GPU is partitioned into streaming multiprocessors (SMs). The way this manifests in the CUDA programming model is that each CUDA thread block (or CTA) is scheduled on exactly one SM, but multiple blocks can be scheduled onto a single SM at a time.
Each SM contains a chunk of fast memory called shared memory (SMEM) and 4 subdivisions, each containing a warp scheduler and compute units (ALU, TensorCore, …). This is also reflected in CUDA programs: each warp (a group of 32 consecutive CUDA threads in a block) is assigned to one of those subdivisions in a round-robin fashion. Similarly to blocks, each warp is assigned to exactly one subdivision (it never migrates), but multiple warps can be assigned to the same SM subdivision. At each clock cycle, the warp scheduler from each subdivision tries to select one of its resident warps to execute the next instruction.
Going further, recent CUDA versions also outline the concept of a warpgroup, which is a group of 4 consecutive warps. Knowing what the hardware looks like, we can see where this is coming from: 4 consecutive warps occupy the 4 quarters of an SM and let us issue instructions that utilize the whole SM.
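To make the hierarchy above concrete, here is a minimal plain-Python sketch of the mapping from a linear CUDA thread index within a block to its warp, SM subdivision and warpgroup. The helper `locate` and the constants are hypothetical names introduced for illustration only, and the round-robin assignment follows the simplified model described above rather than any documented hardware guarantee.

```python
WARP_SIZE = 32            # CUDA threads (lanes) per warp
SUBDIVISIONS_PER_SM = 4   # SM quarters, each with its own warp scheduler
WARPS_PER_WARPGROUP = 4   # a warpgroup is 4 consecutive warps


def locate(cuda_thread_idx: int) -> dict:
    """Maps a linear CUDA thread index within a block to its place in the hierarchy."""
    warp = cuda_thread_idx // WARP_SIZE
    return {
        "warp": warp,
        # Simplified model: warps are dealt out to SM quarters round-robin.
        "subdivision": warp % SUBDIVISIONS_PER_SM,
        "warpgroup": warp // WARPS_PER_WARPGROUP,
    }


for tid in (0, 31, 32, 127, 128, 255):
    print(tid, locate(tid))
```

Under this model, the 4 warps of one warpgroup always land on 4 different subdivisions, which is why a warpgroup-wide instruction can occupy the whole SM.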
Note
A GPU can be viewed in many different ways and here we want to focus on a slightly simplified model that is very TensorCore-centric. This should help you navigate the complexities of writing kernels involving the TensorCore, but keep in mind that the real picture is more complicated.
For our purposes, TensorCore operations have grown so big that it no longer makes much sense to follow the CUDA model. As such, to us, a GPU is a collection of single-threaded cores (SMs) with one thread of Pallas:MGPU corresponding to a CUDA warpgroup. In this model, each operation you perform in the kernel occupies the whole CUDA warpgroup, and its constituent warps always run in lockstep (modulo the jitter from hardware scheduling) and never take different paths through control flow (with the small exception of core_map that we will discuss later). One notable addition here is that we still allow you to co-schedule multiple of those Pallas-level threads on the same SM so that they can cooperate and communicate through shared memory (we realize that by putting them in the same CUDA block).
Note
From now on, whenever we say “thread”, we refer to the Pallas thread, not a CUDA thread/lane.
Note
This is very similar to a programming model popularized by Triton, but as you will see there are a few differences. Mosaic GPU tends to be more low level, which usually means you will have to put in more work, but it also puts you more in control. In our view both approaches have their merits and we encourage you to pick the backend that suits your needs the best! Pallas supports and will continue to support Triton as an alternative GPU backend.
In-order execution & using multiple hardware units#
Unlike more complicated CPU architectures, GPUs only support in-order execution. That, however, does not mean that at any given time only a single instruction is running! Each SM quarter has multiple independent functional units: TensorCore, arithmetic logic unit (ALU), load/store unit (LSU), special function unit (SFU). If the first instruction targets one of the units and is followed by another one (that does not use the result of the first one), then the warp scheduler can issue the second one before the first one completes. This is often referred to as instruction-level parallelism (ILP) and is a common theme in modern TensorCore kernels: TensorCore operations are so big and take so many cycles to complete that it is a waste not to try to use other units in the meantime.
To extend this even further, we can take advantage of this hardware-unit-level parallelism by allowing multiple Pallas threads to run concurrently. If one of the threads primarily occupies the ALU, while another one primarily issues TensorCore-related instructions, we can take advantage of the efficient context switching built into the warp schedulers to keep both units busy. This is one of the core ideas behind algorithms such as FlashAttention 3 or CUTLASS ping-pong matmul kernels.
For more information on how warp scheduling and instruction issue work, we recommend reading Analyzing Modern NVIDIA GPU cores.
Memory spaces#
The GPU features a few different memory spaces that can be totally ordered from the largest (in terms of capacity) and slowest (in both total bandwidth and latency of a single access) to the smallest and fastest.
The biggest memory space is plgpu.GMEM, for global memory. In recent data-center grade GPUs this memory space is often measured in tens or even hundreds of gigabytes, but it is also the slowest one.
The next memory space, used for the L2 cache, is also more or less global in the sense that it is shared by the whole GPU, but its use can only be influenced indirectly through cache hints. As such, there’s no way to manually place values in there and so this memory space is not exposed in Pallas:MGPU. While only about 100MB in size, this memory has considerably higher bandwidth than GMEM, and so it is still often recommended to take advantage of it while writing high-performance kernels.
Next in line is shared memory, or plgpu.SMEM. This memory is located directly inside each SM and so it is partitioned. Unless block clusters are used (see the section on clusters below), each block is only allowed to access its own SMEM allocations.
Finally, the lowest level memory space is the register memory. This is where every single value (i.e. JAX array) in a Pallas kernel will be located. If the compiler runs out of registers to store those arrays, it will insert spills, meaning that it will periodically store and reload values to memory. Those spills often introduce other significant performance degradations and so we recommend avoiding them. The warning messages about spills can be clearly seen in the ptxas messages during kernel compilation. To make them visible, run with MOSAIC_GPU_DUMP_PTXAS=1 in your environment.
The Blackwell GPU generation has one additional memory space called tensor memory or plgpu.TMEM. TMEM is very similar to register memory, only it is explicitly allocated and managed by you. It is used to store the MMA accumulator, operand metadata (for sparsity or scaling), and optionally the left MMA operand. See the Blackwell MMA section for more information about TMEM.
Requesting/allocating memory in specific memory spaces#
Kernel inputs or outputs are placed in SMEM by default. If you want to access them as GMEM references, add memory_space=plgpu.GMEM to their BlockSpec. If you want the kernel to be called with the whole input or output array in GMEM, it is sufficient to specify BlockSpec(memory_space=plgpu.GMEM).
SMEM and TMEM can be allocated explicitly in the scratch_shapes argument of pl.pallas_call, or using pl.run_scoped. To allocate a reference, simply call the memory space object with the requested shape and dtype. For example: plgpu.SMEM((128, 128), jnp.float16) will allocate a 128x128 array of float16 elements in shared memory.
Taking advantage of the L2 cache#
While the L2 cache cannot be managed manually, its noticeably higher bandwidth compared to global memory makes it worth thinking about. The simplest way to take advantage of it is to reorder the parallel grid dimensions so that invocations that are scheduled in similar time periods also access the same input data.
While the CUDA programming model does not guarantee anything about the order in which the blocks are assigned to SMs, in recent generations the heuristic seems to simply iterate over the (x, y, z) CUDA grids in column-major order (i.e. x is the fastest-changing dimension and z is the slowest). Similarly, Pallas:MGPU does not guarantee how a user-specified grid is mapped to the CUDA grid (Pallas supports grids of arbitrary rank, not just up to 3D). However, you can assume that the iteration will happen in row-major order. That is, if a grid has dimensions (a, b), then b will be the fastest-changing dimension and a will be the slower one.
To give a practical example of this, consider a plain matrix multiplication kernel. There, one usually uses two parallel grid dimensions (m, n), corresponding to tiling the two non-contracting dimensions. If we use this simple scheme, in Pallas:MGPU all programs with id (0, ...) will be scheduled before any block with id (1, ...). And, collectively, the programs with m=0 have to read all of the B operand! If the n or k dimensions are very large, there is no chance that we’ll be able to get cache hits from the (1, ...) programs from accesses made by the (0, ...) programs. For simplicity, assuming we can only run 16 blocks at a time, we see this access pattern from the first scheduled wave:
However, if we simply rearrange the grid to be (m // mt, n, mt) (and then replace pl.program_id(0) with pl.program_id(0) * mt + pl.program_id(2) in the kernel), it is straightforward to see that a band of programs along both dimensions will be scheduled concurrently (instead of scheduling a single row). This greatly increases the number of concurrent programs that load similar slices of data, which usually significantly improves the L2 utilization and hence the overall performance of the kernel (if it was memory bound). Continuing our example with 16 blocks and using mt=4, we get the following access pattern:
Note that even though the number of active blocks hasn’t changed, the total footprint of the data they access has halved! We get a much higher chance of getting L2 hits now.
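The effect of the grid rearrangement can be checked with a small plain-Python sketch. It assumes a hypothetical problem with 16 tiles along each of m and n, row-major iteration over the grid, and 16 concurrently active blocks; the helpers `first_wave` and `footprint` are illustrative names, not part of Pallas. The footprint is counted as the number of distinct row-tiles of A plus distinct column-tiles of B touched by the first wave.

```python
import itertools


def first_wave(grid, active_blocks):
    """First `active_blocks` program ids when `grid` is iterated in row-major order."""
    return list(itertools.islice(itertools.product(*map(range, grid)), active_blocks))


def footprint(tiles):
    """Distinct A row-tiles plus distinct B column-tiles touched by (m, n) output tiles."""
    return len({m for m, _ in tiles}) + len({n for _, n in tiles})


m_tiles, n_tiles, mt, active = 16, 16, 4, 16

# Plain (m, n) grid: the first wave is a single row of output tiles.
plain = first_wave((m_tiles, n_tiles), active)

# Rearranged (m // mt, n, mt) grid. The m index is recovered the same way as
# pl.program_id(0) * mt + pl.program_id(2) in the kernel.
rearranged = [
    (m_outer * mt + m_inner, n)
    for m_outer, n, m_inner in first_wave((m_tiles // mt, n_tiles, mt), active)
]

print("plain:", footprint(plain))            # 1 row-tile of A + 16 column-tiles of B
print("rearranged:", footprint(rearranged))  # 4 row-tiles of A + 4 column-tiles of B
```

With these assumed sizes the first wave covers a 4x4 band of output tiles instead of a 1x16 row, so the number of distinct operand tiles it reads drops to roughly half.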
Array layouts and memory reference transforms#
In Pallas, the data structures you work with (arrays and references) have a logical shape (e.g., a 128x128 matrix). This logical shape must be mapped to a physical representation (how the data is actually represented in the GPU’s memory). The specific mapping depends on where the data resides:
Array Layouts: Arrays are stored in register memory and we call this mapping a layout. Layouts define how the elements of an array are distributed across the registers available to the CUDA lanes that form a Pallas thread.
Memory Reference Transforms: For mutable references pointing to SMEM, this mapping is called a transform. Transforms describe how the logical data structure is arranged within that block of memory.
These concepts are crucial for performance, especially when interacting with specialized hardware units like TensorCores or optimizing memory access patterns.
Note
We are working on a mode that will deal with assigning layouts and transforms fully automatically (although with ways to provide hints and more control). The APIs listed below will likely continue to function, but will become optional.
Memory reference transforms#
Transforms are applied when a memory reference is first allocated. Pallas primitives that operate on these references will automatically account for their associated transforms.

    def body(..., scratch_ref):
      # Asynchronous copy will reformat the GMEM data to match the SMEM transforms
      plgpu.copy_gmem_to_smem(..., scratch_ref, barrier)
      plgpu.barrier_wait(barrier)
      plgpu.wgmma(..., scratch_ref)  # wgmma only accepts properly transformed refs
      ...

There are two ways in which references are allocated and each has a way to select the desired transforms:
1. Using plgpu.BlockSpec

    transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
    f = pl.pallas_call(
      in_specs=plgpu.BlockSpec(in_block_shape, in_index_map, transforms=transforms),
      out_specs=plgpu.BlockSpec(out_block_shape, out_index_map, transforms=transforms),
      ...
    )

Note that unlike plgpu.BlockSpec, pl.BlockSpec does not allow specifying transforms.
2. Specifying the transforms argument on the allocated SMEM

    transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
    f = pl.pallas_call(
      scratch_shapes=plgpu.SMEM((128, 128), jnp.float16, transforms=transforms),
      ...
    )
The available transforms are:
plgpu.TilingTransform(tile_shape), which organizes the data into contiguous, non-overlapping tiles of shape tile_shape. The data of one tile is always fully linearized (row-major) before another tile begins (tiles are also traversed in row-major order). As an example, applying TilingTransform((8, 64)) to a (128, 128) reference means the data corresponding to the logical slice [0:8, 0:64] will be stored first (row-major), followed by [0:8, 64:128], [8:16, 0:64], [8:16, 64:128], and so on. A different way to achieve this would be to take the input array x and traverse x.reshape(128 // 8, 8, 128 // 64, 64).transpose(0, 2, 1, 3) in row-major order.
plgpu.SwizzleTransform(swizzle_in_bytes), which transforms the data as described in the PTX docs and CUDA docs. Swizzling is useful because it allows transferring data in MMA-related layouts between register and shared memory without bank conflicts. The exact details of what the memory looks like after swizzling are not that important, since all primitives will account for it automatically. Note that the swizzle amount is specified in bytes (only 128, 64, 32 and 16 are supported), and is usually accompanied by a TilingTransform (which uses elements in its shape!).
plgpu.TransposeTransform(permutation), which permutes the dimensions of the array before it is linearized. This is primarily useful in that it lets you change the layout during the GMEM-SMEM copies (only do keep in mind that changing the minormost/last dimension is not supported by the hardware).
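The tile ordering described for the (8, 64) tiling of a (128, 128) reference can be reproduced on the host with NumPy. This is only a sketch of the resulting element order, not of how Pallas implements the transform: the rows are split into (tile index, row within tile), the columns into (tile index, column within tile), and the two tile indices are then moved to the front.

```python
import numpy as np

rows, cols = 128, 128
tile_rows, tile_cols = 8, 64

# Each element stores its own row-major index so the ordering is easy to inspect.
x = np.arange(rows * cols).reshape(rows, cols)

# (row_tile, row_in_tile, col_tile, col_in_tile) -> (row_tile, col_tile, row_in_tile, col_in_tile)
tiled = x.reshape(rows // tile_rows, tile_rows, cols // tile_cols, tile_cols).transpose(0, 2, 1, 3)
flat = tiled.reshape(-1)  # row-major traversal of the tiled array

tile_elems = tile_rows * tile_cols
expected_order = [x[0:8, 0:64], x[0:8, 64:128], x[8:16, 0:64], x[8:16, 64:128]]
for i, logical_slice in enumerate(expected_order):
    stored = flat[i * tile_elems:(i + 1) * tile_elems]
    np.testing.assert_array_equal(stored, logical_slice.reshape(-1))
print("first four tiles are stored in the expected order")
```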
Array layouts#
There are a few useful layouts we have defined for you so far:
plgpu.Layout.WGMMA, which is the layout in which the Hopper-generation TensorCore expects the MMA accumulator or 16-bit input operands to be held in registers.
plgpu.Layout.WGMMA_ROW, which is the layout obtained from the one above by reducing along the rows. Re-broadcasting the rows is free and will produce a value with WGMMA layout.
plgpu.Layout.WGMMA_COL, which is an analogue of the one above, only reduced along columns instead of rows.
plgpu.Layout.WG_STRIDED, where the value is partitioned equally among the 128 CUDA lanes making up a Pallas thread. The consecutive elements (after vectorization) are assigned to the lanes in a round-robin fashion. Very simple and effective when no interaction with TensorCores is needed.
plgpu.Layout.WG_SPLAT, indicating that the value is constant. Each CUDA lane will hold a single register that contains the value. You normally never have to interact with this layout, as it is implicitly used when constant values are created and is always implicitly convertible to other layouts.
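As a rough illustration of the round-robin assignment described for WG_STRIDED, the NumPy sketch below computes which of the 128 lanes would own each element of a flattened array. The helper name `strided_lane_assignment` and the vector length of 4 are assumptions made for the example; the vectorization actually chosen by the compiler may differ.

```python
import numpy as np

LANES_PER_PALLAS_THREAD = 128  # 4 warps x 32 CUDA lanes


def strided_lane_assignment(num_elems: int, vec_size: int) -> np.ndarray:
    """Lane owning each flattened element, with vectors dealt out round-robin."""
    vector_index = np.arange(num_elems) // vec_size
    return vector_index % LANES_PER_PALLAS_THREAD


lanes = strided_lane_assignment(1024, vec_size=4)
print(lanes[:8])                  # the first two vectors go to lanes 0 and 1
print(lanes[512])                 # after 128 vectors the assignment wraps around
print(np.bincount(lanes).tolist()[:4])  # every lane holds the same number of elements
```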
At the moment, in the default mode of operation, array layout propagation happens only in a forward direction and there is little implicit support for reconciling layout conflicts: only splat layouts can be implicitly converted into any other layout. If you e.g. try to add two arrays that have a different layout, the lowering will complain and fail. There are very limited facilities that let you convert between layouts, and we usually recommend storing the value to SMEM and reading it back in the target layout.
MMA (TensorCore)#
In this section, we focus on how Pallas:MGPU kernels can utilize the TensorCore unit. The programming interface of the TensorCore changes significantly between different NVIDIA GPU generations, which is why the lowest-level interfaces differ in Pallas:MGPU as well.
Each MMA operation is associated with three operands:
the accumulator D of shape (M, N),
the left input A of shape (M, K),
the right input B of shape (K, N).
All operands must have the same element type.
Each use of MMA involves a few steps:
1. Allocating the space for the accumulator (MMA implicitly performs D += A @ B)
2. Preparing the A and B operands
3. Issuing the operation
4. Waiting for the operation to complete
5. Reading out the result
Steps 2-4 are usually performed in a loop over the contraction dimension (K).
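The five steps can be mirrored by a host-side NumPy sketch that accumulates D += A @ B over chunks of the contraction dimension. It only illustrates the arithmetic structure of the loop: nothing here is asynchronous, and the sizes and the chunk width `k_tile` are arbitrary values chosen for the example.

```python
import numpy as np

M, N, K, k_tile = 64, 32, 256, 64
rng = np.random.default_rng(0)
a = rng.standard_normal((M, K)).astype(np.float32)
b = rng.standard_normal((K, N)).astype(np.float32)

acc = np.zeros((M, N), dtype=np.float32)   # 1. allocate (zero-initialized) accumulator
for k0 in range(0, K, k_tile):
    a_tile = a[:, k0:k0 + k_tile]          # 2. prepare the A operand of shape (M, k_tile)
    b_tile = b[k0:k0 + k_tile, :]          # 2. prepare the B operand of shape (k_tile, N)
    acc += a_tile @ b_tile                 # 3.-4. issue the MMA (NumPy completes it immediately)
result = acc                               # 5. read out the result

np.testing.assert_allclose(result, a @ b, rtol=1e-4, atol=1e-3)
print("chunked accumulation matches a @ b")
```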
Memory space of A and B operands#
The A and B operands are generally best passed in through SMEM, where they can be conveniently loaded using plgpu.copy_gmem_to_smem. For those operands to be compatible with MMA operations, they need to have the appropriate tiling and swizzling transforms specified upon their allocation. For all currently supported generations, the TensorCore requires the data to be laid out into row-major 2D tiles of shape (8, swizzle_elems), where swizzle_elems is derived by dividing the swizzle by the element type bytewidth. The currently supported swizzles are: 128, 64, and 32. Larger swizzles are preferable as they improve the performance of GMEM-to-SMEM copies.

    def mma_transforms(shape_dtype: jax.ShapeDtypeStruct):
      assert len(shape_dtype.shape) == 2
      if shape_dtype.shape[0] % 8:
        raise ValueError("Number of rows must be divisible by 8")
      # Prefer the largest swizzle that evenly divides the minor dimension.
      for swizzle_bytes in (128, 64, 32):
        swizzle_elems = swizzle_bytes // shape_dtype.dtype.itemsize
        if shape_dtype.shape[-1] % swizzle_elems == 0:
          return (plgpu.TilingTransform((8, swizzle_elems)),
                  plgpu.SwizzleTransform(swizzle_bytes))
      raise ValueError("Failed to find transforms for the specified window type")
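To see which tile shape and swizzle the selection logic above would pick for a few concrete windows, here is a host-only variant that returns plain tuples instead of Pallas transform objects. The function name `pick_tile_and_swizzle` is hypothetical, and the example shapes are arbitrary.

```python
import numpy as np


def pick_tile_and_swizzle(shape, dtype):
    """Returns ((8, swizzle_elems), swizzle_bytes) for the largest swizzle that fits."""
    rows, cols = shape
    if rows % 8:
        raise ValueError("Number of rows must be divisible by 8")
    itemsize = np.dtype(dtype).itemsize
    for swizzle_bytes in (128, 64, 32):
        swizzle_elems = swizzle_bytes // itemsize
        if cols % swizzle_elems == 0:
            return (8, swizzle_elems), swizzle_bytes
    raise ValueError("Failed to find transforms for the specified window type")


print(pick_tile_and_swizzle((128, 128), np.float16))  # 128-byte swizzle, 64 elements per row
print(pick_tile_and_swizzle((64, 48), np.float32))    # 48 is not a multiple of 32, so fall back
print(pick_tile_and_swizzle((64, 32), np.int8))       # only the 32-byte swizzle divides 32
```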
If the operands need to be transformed, the A operand can be passed in through a different memory space (architecture dependent, see below). The B operand must be located in SMEM.
Transposed operands#
When performing MMA on 16-bit operands, the TensorCore can automatically transpose the input data. For instance, the A reference is allowed to be of shape (K, M), but it has to be transposed before passing it into the mma function:

    assert acc_ref.shape == (M, N) and a_ref.shape == (K, M) and b_ref.shape == (K, N)
    a_ref_t = plgpu.transpose_ref(a_ref, (1, 0))
    assert a_ref_t.shape == (M, K)  # The shape expected by plgpu.wgmma
    plgpu.wgmma(acc_ref, a_ref_t, b_ref)

An analogous operation is allowed on the B reference in this case too.
Hopper (wgmma)#
In this section, we cover the basics of using the Hopper-generation TensorCores, exposed in PTX as the wgmma.mma_async instruction.
Allocating the accumulator#
In the Hopper hardware architecture the accumulator is allocated in registers, but in Pallas it is modeled as a mutable reference, as each MMA operation accumulates in-place. There are two ways to allocate the accumulator.
To create a zero-initialized accumulator you can use pl.run_scoped with a plgpu.ACC((m, n), dtype) type.

    def compute(acc_ref):
      ...
      return acc_ref[...]
    output = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32))

Dereferencing the accumulator reference, as seen at the end of the compute function, will implicitly await all outstanding WGMMA operations.
If you’d like to initialize it with an existing array, you can use pl.run_state with plgpu.ACC.init(init_array):

    def compute(acc_ref):
      ...
      return  # pl.run_state only returns the final value of the accumulator
    output = pl.run_state(compute)(plgpu.ACC.init(init_array))

If pl.run_state has accumulator operands, it implicitly awaits all outstanding WGMMA operations before returning the final values.
Preparing the A and B operands#
As discussed above, we recommend passing in A and B through shared memory. In this case the correct tiling and swizzling transforms must be specified.
plgpu.wgmma additionally allows passing in A through registers (i.e. not an SMEM reference but as a regular JAX array). This mode, however, comes with a number of significant drawbacks and it is very difficult to ensure sufficient synchronization to make this safe.
TODO: Explain the conditions under which it is acceptable to do this.
Issuing the operation#
The supported MMA shapes are such that:
M is divisible by 64
N is divisible by 8 and not greater than 256
K is a multiple of swizzle divided by the operand’s element type bytewidth
The currently supported data types are: jnp.float32, jnp.bfloat16 and jnp.float16. The accumulator D must be a jnp.float32, with the exception of jnp.float16 inputs, in which case it is allowed to be jnp.float16 as well.
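The shape constraints listed above are easy to check ahead of time. The sketch below is a hypothetical host-side helper (`check_wgmma_shape` is not a Pallas API) that returns a list of violated constraints for a given problem shape, operand dtype and swizzle; it does not attempt to validate dtypes.

```python
import numpy as np


def check_wgmma_shape(m, n, k, dtype, swizzle_bytes=128):
    """Returns human-readable descriptions of the violated WGMMA shape constraints."""
    swizzle_elems = swizzle_bytes // np.dtype(dtype).itemsize
    problems = []
    if m % 64:
        problems.append(f"M={m} is not divisible by 64")
    if n % 8 or n > 256:
        problems.append(f"N={n} must be divisible by 8 and not greater than 256")
    if k % swizzle_elems:
        problems.append(f"K={k} is not a multiple of {swizzle_elems}")
    return problems


print(check_wgmma_shape(128, 256, 64, np.float16))  # no violations
print(check_wgmma_shape(100, 512, 32, np.float16))  # all three constraints are violated
```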
Waiting for the operation to complete#
Each plgpu.wgmma call implicitly synchronizes with all previous plgpu.wgmma calls, such that once control returns from it, we guarantee that no WGMMA other than the last issued one is still running. As such, any SMEM regions that were read by previously issued WGMMA instructions can be reused. This is especially relevant for pipelining WGMMA with async memory copies:

    buffers = 3  # In reality you might want even more
    assert a_smem.shape == (buffers, m, k)
    assert b_smem.shape == (buffers, k, n)
    assert acc_ref.shape == (m, n)

    def fetch_a_b(ki, slot):
      a_slice = ...  # Replace with the right M/K slice
      b_slice = ...  # Replace with the right K/N slice
      plgpu.copy_gmem_to_smem(a_gmem.at[a_slice], a_smem.at[slot], a_loaded.at[slot])
      plgpu.copy_gmem_to_smem(b_gmem.at[b_slice], b_smem.at[slot], b_loaded.at[slot])

    def loop_body(i, _):
      slot = jax.lax.rem(i, buffers)
      plgpu.barrier_wait(a_loaded.at[slot])
      plgpu.barrier_wait(b_loaded.at[slot])
      plgpu.wgmma(acc_ref, a_smem.at[slot], b_smem.at[slot])
      # We know that only the last issued WGMMA is running, so we can issue an async
      # load into the buffer that was read by the previous (already completed) WGMMA.
      load_i = i + buffers - 1
      load_slot = jax.lax.rem(load_i, buffers)
      @pl.when(jnp.logical_and(load_i >= buffers, load_i < num_steps))
      def _do_fetch():
        fetch_a_b(load_i, load_slot)

    # Prologue: fill all the buffers before entering the loop.
    for slot in range(buffers):
      fetch_a_b(slot, slot)
    jax.lax.fori_loop(0, num_steps, loop_body, None)
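The buffer rotation in a pipeline of this kind can be sanity-checked with a plain-Python simulation. The sketch below assumes the refill at step i targets slot (i + buffers - 1) % buffers, i.e. the buffer read by the WGMMA issued one step earlier; the helper `simulate` is a hypothetical name and no GPU is involved.

```python
def simulate(num_steps: int, buffers: int):
    """Returns (step, slot read by the running WGMMA, slot being refilled or None)."""
    schedule = []
    for i in range(num_steps):
        running_slot = i % buffers
        load_i = i + buffers - 1
        # The prologue already fetched steps 0..buffers-1, so only later steps are fetched here.
        refill_slot = load_i % buffers if buffers <= load_i < num_steps else None
        schedule.append((i, running_slot, refill_slot))
    return schedule


for step, running, refill in simulate(num_steps=6, buffers=3):
    print(f"step {step}: WGMMA reads slot {running}, refill targets slot {refill}")
```

In every step the refilled slot differs from the slot the in-flight WGMMA is reading, which is the property the pipelined loop relies on.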
Blackwell (tcgen05)#
The Blackwell generation has significantly redesigned the TensorCore subunit. It is now significantly more independent from the regular warp schedulers and no longer uses or even supports using registers as its operands. In their place, a new memory space called tensor memory (TMEM) has been introduced. What’s more, TensorCores from pairs of SMs can now pool their resources and compute larger MMA operations that span both SMs. We call this a “collective MMA operation”.
Allocating the accumulator / Using TMEM#
TMEM references can be allocated in the same way in which all other references are allocated, using pl.run_scoped:

    @functools.partial(pl.run_scoped, tmem_ref=plgpu.TMEM((128, 128), jnp.float32))
    def barrier_scope(tmem_ref):
      ...

Not all shapes can be allocated in TMEM. Only 2D references are supported, and the number of rows (the size of the first dimension) must be 128 or 64 at the moment.
What’s more, if the data type has a bitwidth smaller than 32 bits, it is necessary to declare if the allocation is supposed to be packed (e.g. putting two 16-bit elements into a single 32-bit cell in TMEM) or not (with each element padded up to 32 bits). MMA accumulators (fp32 or fp16) are never packed, but if the left operand is passed in TMEM, it must always be packed:

    @functools.partial(pl.run_scoped,
                       acc_ref=plgpu.TMEM((128, 128), jnp.float16, packed=False),
                       lhs_ref=plgpu.TMEM((128, 128), jnp.float16, packed=True))
    def barrier_scope(acc_ref, lhs_ref):
      plgpu.tcgen05_mma(acc_ref, lhs_ref, rhs_smem_ref, ...)
      ...
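The difference between packed and unpacked allocations can be made concrete by counting 32-bit TMEM cells per row. The sketch below is an assumption-laden illustration: it takes the description above at face value (one element per 32-bit cell when unpacked, as many elements as fit when packed) and uses a hypothetical helper `tmem_cells_per_row`; it does not model how TMEM is actually addressed.

```python
import numpy as np

TMEM_CELL_BITS = 32


def tmem_cells_per_row(shape, dtype, packed: bool) -> int:
    """Number of 32-bit cells needed for one row of a 2D TMEM allocation."""
    _, cols = shape
    bits = np.dtype(dtype).itemsize * 8
    if bits > TMEM_CELL_BITS:
        raise ValueError("Element type is wider than a TMEM cell")
    elems_per_cell = TMEM_CELL_BITS // bits if packed else 1
    return -(-cols // elems_per_cell)  # ceiling division


print(tmem_cells_per_row((128, 128), np.float16, packed=False))  # accumulator-style allocation
print(tmem_cells_per_row((128, 128), np.float16, packed=True))   # left-operand-style allocation
print(tmem_cells_per_row((128, 128), np.float32, packed=False))  # packing is irrelevant for 32-bit
```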
Another interesting complication with TMEM is that all operations on it are asynchronous. For that reason, reads and writes using the Python subscript syntax that are normally used e.g. for SMEM are not allowed for TMEM.
Loads#
Loads can be performed using plgpu.async_load_tmem and awaited using plgpu.wait_load_tmem:

    smem_ref[...] = plgpu.async_load_tmem(tmem_ref)
    plgpu.commit_smem()
    plgpu.copy_smem_to_gmem(smem_ref, gmem_ref)
    plgpu.wait_smem_to_gmem(0)
    plgpu.wait_load_tmem()  # Wait for the read to fully complete before we overwrite tmem_ref again.

The load semantics are quite confusing, in that the array returned from the load can be safely used without any additional synchronization. However, if the read TMEM region is ever overwritten again (e.g. by a store or an MMA operation), the thread that issued the load must first call plgpu.wait_load_tmem() to ensure the program remains race-free.
Note
One way to make peace with this seemingly causality-breaking behavior (data arrives in registers before it is fully read from TMEM) is to consider that it might be an effect of an interaction of a limitation and a convenience feature in the PTX compiler. We don’t know if this is true, but at least it makes sense.
The convenience feature is that the compiler can reliably track the usage of registers produced by TMEM loads and will insert the minimum number of delays necessary to ensure the data arrives from TMEM before it’s used. The read operation is unrolled into many instructions, meaning that they don’t have to all be awaited before we start consuming the registers filled in by the first load. This is why we don’t need to guard the use of the result.
The limitation is that the compiler cannot reliably perform alias analysis on TMEM loads and stores, which is why any load and store that is not separated by an explicit wait is considered safe to execute concurrently. The alternative would unnecessarily pessimize the performance of loads and stores that are truly unrelated. This is why we need to explicitly wait before we reuse TMEM again.
Stores#
Conversely, stores are performed using plgpu.async_store_tmem and awaited using plgpu.commit_tmem:

    plgpu.async_store_tmem(tmem_ref, smem_ref[...])
    plgpu.commit_tmem()
    smem_ref2[...] = plgpu.async_load_tmem(tmem_ref)  # Safe to read from tmem_ref now
Preparing the A and B operands#
We recommend passing in A and B through shared memory. In this case the correct tiling and swizzling transforms must be specified. The A operand can be passed in as a TMEM reference as well, but it must be packed.
Issuing the operation#
The supported non-collective MMA shapes are such that:
M is 64 or 128
N is divisible by 8 and not greater than 512
K is a multiple of 8 * swizzle divided by the bitwidth of the element type
The supported collective MMA shapes are such that:
M is 128 or 256 (half of that per block)
N is divisible by 8 and not greater than 256 (not greater than 128 in each block)
K is a multiple of 8 * swizzle divided by the bitwidth of the element type
The currently supported floating-point data types are: jnp.bfloat16, jnp.float16, jnp.float8_e5m2, jnp.float8_e4m3fn. The accumulator can be a jnp.float32 or jnp.float16, with the exception of jnp.bfloat16 inputs, for which it must be a jnp.float32.
The only currently supported integer data type is jnp.int8 with a jnp.int32 accumulator.
Note
According to our benchmarks, here are some performance rules-of-thumb:
Non-collective MMA should always use M=128 and N >= 128.
M=64 causes a significant performance drop.
N=64 causes a noticeable performance drop, but not as significant as M=64.
Collective MMA is always reasonably fast, but not faster than non-collective MMA.
The biggest benefit from collective MMA is not higher TensorCore throughput but the ability to share data between SMs, which makes it possible to increase the arithmetic intensity of the kernel.
Swizzle and transposes do not seem to affect performance in a significant way.
Waiting for the operation to complete#
Awaiting the result of a plgpu.tcgen05_mma call requires the use of a Barrier. We recommend reading through the reference documentation for Barriers, and especially its Blackwell-related subsection, for more information.
If the barrier is passed in directly to plgpu.tcgen05_mma, completing a wait on that barrier will indicate that the final accumulator has been written to TMEM. For example:

    @functools.partial(pl.run_scoped, barrier_ref=plgpu.Barrier(orders_tensor_core=True))
    def barrier_scope(barrier_ref):
      plgpu.tcgen05_mma(acc_tmem, lhs_ref, rhs_ref, barrier_ref, accumulate=False)
      plgpu.barrier_wait(barrier_ref)
      # We can read the result now.
      result = plgpu.async_load_tmem(acc_tmem)
      ...

If no barrier is given to plgpu.tcgen05_mma, its completion will be tracked only once plgpu.tcgen05_commit is called:

    @functools.partial(pl.run_scoped, barrier_ref=plgpu.Barrier(orders_tensor_core=True))
    def barrier_scope(barrier_ref):
      plgpu.tcgen05_mma(acc_tmem, lhs_ref, rhs_ref, accumulate=False)
      plgpu.tcgen05_mma(acc_tmem, lhs_ref2, rhs_ref2)
      plgpu.tcgen05_commit(barrier_ref)
      plgpu.barrier_wait(barrier_ref)
      # We can read the result now. Both MMAs have completed.
      result = plgpu.async_load_tmem(acc_tmem)
      ...
Collective MMA#
The Blackwell generation gains a new way to perform MMA operations, where the TensorCores of 2 SMs in a cluster collaborate on a single MMA operation. The B operand from each SM is shared with the other. The D and A operands are local to each SM and not shared.
This means that to perform a collective MMA with shape M, N, and K, the operands in each of the two Pallas threads should be of sizes: (M // 2, K) for A, (K, N // 2) for B and (M // 2, N) for D (the accumulator). Stacking the two accumulators on top of each other would recover the result of performing an MxNxK matrix multiplication.
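The operand split for a collective MMA can be verified numerically on the host. The NumPy sketch below models the two blocks as list entries: each holds its own half of A and half of B, both see the concatenation of the two B halves, and stacking the per-block accumulators recovers the full product. The sizes are arbitrary values that satisfy the collective shape constraints.

```python
import numpy as np

M, N, K = 256, 128, 64
rng = np.random.default_rng(0)
a = rng.standard_normal((M, K)).astype(np.float32)
b = rng.standard_normal((K, N)).astype(np.float32)

# Per-block operands: A is split along M, B is split along N.
a_blocks = [a[:M // 2], a[M // 2:]]              # each of shape (M // 2, K)
b_blocks = [b[:, :N // 2], b[:, N // 2:]]        # each of shape (K, N // 2)

# The B halves are shared between the two blocks, so each block sees all of B.
b_shared = np.concatenate(b_blocks, axis=1)      # (K, N)

# Each block computes its own (M // 2, N) accumulator.
d_blocks = [a_block @ b_shared for a_block in a_blocks]

# Stacking the two accumulators recovers the full M x N result.
d = np.concatenate(d_blocks, axis=0)
np.testing.assert_allclose(d, a @ b, rtol=1e-4, atol=1e-3)
print([blk.shape for blk in d_blocks], d.shape)
```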
To make loading of the B operand easier, plgpu.copy_gmem_to_smem can be used together with collective_axes and partitioned_axis to indicate that the two Pallas threads along the collective axis should load the same slice, but each will only obtain half of it. Unlike a copy with collective_axes alone, it does not utilize TMA multicast (since each thread loads a distinct slice of data), but it can simplify the indexing logic a bit.

    plgpu.copy_gmem_to_smem(
        b_gmem,  # [K, N]
        b_smem,  # [K, N // 2]
        b_tma_barrier,
        collective_axes="x",
        partitioned_axis=1,
    )
Using core_map#
pl.pallas_call is suitable for kernels where a single Pallas thread can perform the whole computation for an entire CUDA block. The pl.core_map function relaxes this restriction, allowing for using multiple threads within a single block (e.g. for warp specialization) or across multiple blocks in a block cluster (e.g. to utilize multicast TMA).
Replacing pl.pallas_call with pl.core_map or plgpu.kernel#
Let us begin with a simple Pallas kernel that increments an array:

    @functools.partial(
      pl.pallas_call,
      grid=(2,),
      in_specs=[pl.BlockSpec(block_shape=(128,), index_map=lambda i: (i,))],
      out_specs=pl.BlockSpec(block_shape=(128,), index_map=lambda i: (i,)),
      out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),  # Total output shape
    )
    def run_kernel(x_ref, y_ref):
      # x_ref and y_ref are in SMEM!
      y_ref[...] = x_ref[...] + 1

    x = jnp.arange(256, dtype=jnp.float32)
    y = run_kernel(x)
    np.testing.assert_array_equal(y, x + 1)

We can write a similar kernel using pl.core_map. One big difference is that unlike pl.pallas_call, no GMEM<->SMEM copies will be inserted automatically. If you want them, you can either insert them yourself or use the plgpu.emit_pipeline helper. We recommend reviewing the software pipelining guide.

    @pl.run_state
    def run_kernel(refs):
      x_ref, y_ref = refs
      # Here, we're not in the kernel yet! pl.run_state simply changes the JAX
      # immutable arrays into mutable GMEM (not SMEM!) references.

      # Define the mesh: 2 CUDA blocks over 1 axis called "x"
      mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))

      @pl.core_map(mesh)  # core_map executes the body
      def kernel_body():
        # Once we enter the pl.core_map scope, we are in the body of the kernel.
        block_slice = pl.ds(jax.lax.axis_index("x") * 128, 128)
        y_ref[block_slice] = x_ref[block_slice] + 1

    x = jnp.arange(256, dtype=jnp.float32)
    y_init = jnp.zeros_like(x)
    _, y = run_kernel((x, y_init))
    np.testing.assert_array_equal(y, x + 1)

While pl.core_map is a powerful API, it is also quite low-level and is pretty much always used under pl.run_state (to make JAX arrays into refs) or pl.run_scoped (to allocate scratch refs). For that reason, we also provide a convenience API plgpu.kernel:

    @functools.partial(
      plgpu.kernel,
      out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),
      grid=(2,),
      grid_names=("x",),
    )
    def run_kernel(x_ref, y_ref):
      # x_ref and y_ref are in GMEM!
      block_slice = pl.ds(jax.lax.axis_index("x") * 128, 128)
      y_ref[block_slice] = x_ref[block_slice] + 1

    x = jnp.arange(256, dtype=jnp.float32)
    y = run_kernel(x)  # No need to preallocate outputs as in pl.core_map.
    np.testing.assert_array_equal(y, x + 1)
Note
The plgpu.Mesh used with pl.core_map defines a topology for computation within a single GPU, specifying how work is distributed across CUDA blocks (the grid), Pallas threads within a block (num_threads), and potentially CUDA block clusters (cluster). This is analogous to how jax.sharding.Mesh defines a topology for distributed computation across multiple devices in JAX. Both involve SPMD programs executing across the defined topology. Furthermore, you can run “collectives” over the Pallas threads and cluster (e.g., using plgpu.ClusterBarrier or collective async copies), similar to how JAX collectives (psum, all_gather, etc.) operate across devices in a JAX Mesh. Both also use named axes, and jax.lax.axis_index(axis_name) can be used to get a thread’s or block’s coordinate.
Using multiple Pallas threads per CUDA block#
Below, you can find an example of two Pallas threads within a single block synchronizing through a barrier and even exchanging data through SMEM.
x = jnp.arange(128, dtype=jnp.float32)

@functools.partial(
    plgpu.kernel,
    out_shape=x,
    scratch_shapes=dict(
        smem_ref=plgpu.SMEM(x.shape, x.dtype),
        barrier_ref=plgpu.Barrier(),
    ),
    num_threads=2,
    thread_name="pallas_thread",
)
def run_kernel(x_ref, y_ref, smem_ref, barrier_ref):
  thread_id = jax.lax.axis_index("pallas_thread")

  @pl.when(thread_id == 0)
  def producer_thread():
    smem_ref[...] = x_ref[...] + 1
    plgpu.barrier_arrive(barrier_ref)  # Signal the consumer thread

  @pl.when(thread_id == 1)
  def consumer_thread():
    plgpu.barrier_wait(barrier_ref)  # Wait for the producer thread
    y_ref[...] = smem_ref[...] + 1

y = run_kernel(x)  # There's no need to preallocate the output anymore.
np.testing.assert_array_equal(y, x + 2)
While this example is simple, you can find a more complicated example in the synchronization section.
Multiple threads are frequently used in high-performance kernels such as the latest flash attention variants or ping-pong matrix multiplication. In both of those, there are 2 compute threads in the program that use the SM’s ALU and TensorCore in an alternating fashion to ensure no execution conflicts.
Another common technique is to allocate one Pallas thread and devote it entirely to scheduling asynchronous copies for data consumed by other threads. While implementing this scheme from scratch can be complicated, we provide a convenient helper API: plgpu.emit_pipeline_warp_specialized.
Using CUDA block clusters#
The kernel below launches a single cluster of 2 CUDA blocks and uses the TMA multicast feature to collectively perform a copy of GMEM into SMEM of both blocks. All blocks participating in the collective copy must schedule the exact same copy for the program to be valid.
@functools.partial(
    plgpu.kernel,
    out_shape=jax.ShapeDtypeStruct((2, 128), jnp.float32),
    scratch_shapes=dict(
        smem_ref=plgpu.SMEM((128,), jnp.float32),
        barrier_ref=plgpu.Barrier(),
    ),
    cluster=(2,),
    cluster_names=("cluster",),
)
def run_kernel(x_ref, y_ref, smem_ref, barrier_ref):
  # Specifying collective_axes will enable TMA multicast automatically.
  plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref, collective_axes="cluster")
  plgpu.barrier_wait(barrier_ref)
  plgpu.copy_smem_to_gmem(smem_ref, y_ref.at[jax.lax.axis_index("cluster")])
  plgpu.wait_smem_to_gmem(0)

x = jnp.arange(128, dtype=jnp.float32)
y = run_kernel(x)
# Each block gets the same data and writes it out.
np.testing.assert_array_equal(y, jnp.stack([x, x], axis=0))
Collective allocations in pl.run_scoped#
When using pl.core_map with multiple Pallas threads (i.e., num_threads > 1 in plgpu.Mesh), allocations made via pl.run_scoped (for SMEM or Barriers) must be performed collectively by all threads. This is indicated by specifying a collective_axes argument to run_scoped, which has two effects:
it promises that all threads will call the same allocation, and
all threads will receive the exact same allocation.
If collective_axes is not specified or does not include the Pallas thread axis, each thread would get its own private copy of the scratch variable. This is usually undesired and not supported at the moment.
Global (grid-wide) allocations using pl.get_global#
Sometimes, it is useful to allocate semaphores in a way that enables them to be shared by all the parallel program instances, for example when the number of parallel instances is small enough that the kernel is persistent. Such allocations are possible using pl.get_global:
def body(out_ref):
  sem_ref = pl.get_global(plgpu.SemaphoreType.REGULAR)
  block_id = lax.axis_index("x")

  @pl.when(block_id == 0)
  def _():
    pl.semaphore_signal(sem_ref)  # Block 0 signals

  @pl.when(block_id == 1)
  def _():
    pl.semaphore_wait(sem_ref)  # Block 1 waits
    out_ref[...] = jnp.ones_like(out_ref)

out_shape = jax.ShapeDtypeStruct((128,), jnp.float32)
plgpu.kernel(body, out_shape=out_shape, grid=(2,), grid_names=("x",))()
Synchronization structures and primitives#
In this section, we go over the most important functions and data structures used for synchronization between threads and also some asynchronous operations.
commit_smem#
Regular reads/writes to references are guaranteed to produce values consistent with the sequential program order. For example, in the following program, it is guaranteed that value is equal to value2.
ref[...] = value
value2 = ref[...]
This guarantee, however, does not extend to asynchronous primitives such as async copies or MMA operations. To make the SMEM writes visible to those primitives, you are required to explicitly synchronize with them using the plgpu.commit_smem() function.
For example:
smem_ref[...] = value
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(smem_ref, ...)
or:
smem_ref[...] = value
plgpu.commit_smem()
plgpu.wgmma(smem_ref, ...)
This explicit synchronization is also required in the other direction, for example:
v = plgpu.load(smem_ref, ())
plgpu.commit_smem()
plgpu.copy_gmem_to_smem(..., smem_ref, ...)
Failing to call this function is likely to cause subtle data races, due to those asynchronous hardware units reading stale data from SMEM. Unfortunately, this function is relatively expensive, which is why we rely on you, the user, to insert it in the minimal number of places where it’s necessary.
Barrier#
This is essentially a thin wrapper around an array of PTX mbarrier types and is passed in as a reference. All functions involving barriers expect to only get a single barrier argument, and so if the reference contains multiple, you have to extract one of them explicitly using barriers.at[index]. Barriers are always allocated in SMEM and as such have relatively low overheads. Each barrier can be configured to complete after a fixed number of “arrivals” (by default 1).
To block a thread until a barrier completes, use the following function:
plgpu.barrier_wait(barrier)
Warning
It is critical to ensure that the synchronization scheme makes it impossible for two barrier completions to happen without a call to plgpu.barrier_wait in between them. For example, if you use Barriers to synchronize two producer/consumer threads, you need to perform barrier synchronization going both ways to introduce “backpressure” that will stop one thread from arriving twice before the other one had a chance to await. Failing to satisfy this will corrupt the data structure and can cause surprising failures (including CUDA runtime errors). See below for an example of a valid program with two threads.
Warning
Another critical restriction is that the number of barrier completions must equal the number of barrier waits throughout the barrier’s lifetime. It is not allowed to end a scoped allocation of a barrier when it has an unawaited completion. Otherwise, when it is reused by the compiler, leaving it in this state can cause problems downstream.
Warning
Finally, it is crucial to ensure that each thread that ever waits on a Barrier takes part in all wait operations on it. It is not allowed to e.g. await every other completion of a barrier from one thread, and all other completions from another one. Doing so will lead to deadlocks. To recap: when a Barrier is used to wait in some thread, it must observe every single completion of that barrier (by waiting on it).
Note that the Barrier can receive arrivals from any source, without restrictions.
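The first warning above can be illustrated with a toy, pure-Python model. This is only a sketch under assumptions: `ToyBarrier` and `ToyWaiter` are hypothetical names, not part of Pallas, and the single phase bit is a simplification loosely inspired by how PTX mbarrier objects track completions. The model does not block; it only reports whether a wait would return.

```python
class ToyBarrier:
    """Toy model: completes after `num_arrivals` arrivals, then flips a phase bit."""

    def __init__(self, num_arrivals=1):
        self.num_arrivals = num_arrivals
        self.pending = num_arrivals
        self.phase = 0  # Flips on every completion.

    def arrive(self):
        self.pending -= 1
        if self.pending == 0:
            self.phase ^= 1
            self.pending = self.num_arrivals


class ToyWaiter:
    """Tracks the phase value a thread has most recently observed."""

    def __init__(self, barrier):
        self.barrier = barrier
        self.seen_phase = 0

    def would_unblock(self):
        # A real wait blocks; the model only reports whether it would return.
        return self.barrier.phase != self.seen_phase

    def wait(self):
        assert self.would_unblock(), "a real thread would block here"
        self.seen_phase ^= 1


# Correct use: every completion is observed by a wait before the next one.
good = ToyBarrier()
good_waiter = ToyWaiter(good)
good.arrive()
ok_first = good_waiter.would_unblock()   # A completion is pending.
good_waiter.wait()
ok_second = good_waiter.would_unblock()  # No new completion yet.

# Incorrect use: two completions without a wait in between.
bad = ToyBarrier()
bad_waiter = ToyWaiter(bad)
bad.arrive()
bad.arrive()  # The phase flips back, so both completions become invisible.
missed = not bad_waiter.would_unblock()
```

In this model the waiter in the incorrect case would block forever even though the barrier completed twice, which is one way the missing backpressure described above can show up.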
There are three operations that can complete a barrier:
Asynchronous GMEM-to-SMEM copies#
When an asynchronous GMEM-to-SMEM copy is being executed by the TMA engine, it will post progress updates to the barrier given to plgpu.copy_gmem_to_smem. Once the copy is complete, the barrier will complete one arrival as well.
Explicit arrival (cross-thread synchronization)#
Any thread can explicitly arrive on a barrier using the following function:
plgpu.barrier_arrive(barrier)
This is especially useful when synchronizing two threads that are in producer/consumer roles. In this case, we recommend allocating two arrays of Barriers, with size equal to the size of the “queue” used to pass data between the two threads. For example, assume one thread continues writing tiles of an array to SMEM while another thread reads them. We triple-buffer the SMEM region to allow more asynchrony between the two threads:
tid = jax.lax.axis_index("thread")
assert queue.shape == (buffering, *item_shape)
assert produced.shape == consumed.shape == (buffering,)

def thread0_body(i, _):
  slot = jax.lax.rem(i, buffering)

  @pl.when(i >= buffering)
  def _await_consumed():
    plgpu.barrier_wait(consumed.at[slot])  # Wait for consumption of the value before overwriting it

  # Option 1: Compute the next value
  queue[slot] = produce()
  plgpu.barrier_arrive(produced.at[slot])  # Signal the value is ready
  # Option 2: Produce the value through async_copy
  # plgpu.copy_gmem_to_smem(..., queue.at[slot], barrier=produced.at[slot])

pl.when(tid == 0)(lambda: jax.lax.fori_loop(0, steps, thread0_body, None))

def thread1_body(i, _):
  slot = jax.lax.rem(i, buffering)
  plgpu.barrier_wait(produced.at[slot])  # Wait for the value to be ready
  consume(queue[slot])  # Load and compute
  plgpu.barrier_arrive(consumed.at[slot])  # Signal that the value is consumed

pl.when(tid == 1)(lambda: jax.lax.fori_loop(0, steps, thread1_body, None))
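The same protocol can be tried out on a CPU with ordinary Python threads. The sketch below is only an analogy: it replaces the two arrays of Barriers with lists of `threading.Semaphore` objects, and the produced values (`i * i`) are an arbitrary stand-in for real tiles. It demonstrates the ordering and the backpressure, not the GPU semantics.

```python
import threading

buffering = 3   # Number of slots in the queue (triple buffering).
steps = 10      # Number of items passed from the producer to the consumer.

queue = [None] * buffering
# One "produced" and one "consumed" signal per slot.
produced = [threading.Semaphore(0) for _ in range(buffering)]
consumed = [threading.Semaphore(0) for _ in range(buffering)]
received = []


def producer():
    for i in range(steps):
        slot = i % buffering
        if i >= buffering:
            consumed[slot].acquire()  # Backpressure: wait until the slot is free.
        queue[slot] = i * i           # "Compute" the next value.
        produced[slot].release()      # Signal that the value is ready.


def consumer():
    for i in range(steps):
        slot = i % buffering
        produced[slot].acquire()      # Wait for the value to be ready.
        received.append(queue[slot])  # "Consume" the value.
        consumed[slot].release()      # Signal that the slot can be reused.


threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Python semaphores are more forgiving than Barriers: they count arrivals, so this model would not expose a violation of the warnings above. It only shows that the producer never overwrites a slot before the consumer has read it.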
Awaiting tcgen05 TensorCore instructions#
Before we begin, an important warning:
Warning
On Blackwell-generation GPUs, Barrier operations by default have relaxed semantics with respect to the TensorCore operations. This means that by default any TensorCore-related operation (including TMEM operations) can be moved by the compiler after a barrier signal. Similarly, any TensorCore-related operation can be moved before a barrier wait.
If you mean to use Barriers to indicate to other threads that a TensorCore operation is complete, allocate the barrier with orders_tensor_core=True. This argument will insert the necessary instructions to prevent the problematic reordering mentioned above.
Unlike in older GPUs, the only way to observe the completion of Blackwell-generation TensorCore instructions is to pass in a Barrier reference to the plgpu.tcgen05_mma function. Once the MMA is complete, the TensorCore will arrive on the barrier.
Note that this use of Barriers requires that they are created with orders_tensor_core=True, since they are used to synchronize with TensorCore operations.
@functools.partial(pl.run_scoped, barrier_ref=plgpu.Barrier(orders_tensor_core=True))
def barrier_scope(barrier_ref):
  plgpu.tcgen05_mma(acc_tmem, lhs_ref, rhs_ref, barrier_ref, accumulate=False)
  plgpu.barrier_wait(barrier_ref)
  # We can read the result now
  result = plgpu.async_load_tmem(acc_tmem)
  ...
ClusterBarrier#
ClusterBarrier is very similar to Barrier, only used to synchronize across block clusters, instead of threads within a single block. This is always necessary when the blocks in the cluster collaborate on shared resources. Below we outline some of the more common cases when ClusterBarrier is necessary to ensure correctness.
Reusing SMEM for collective async copies#
In the following example, ClusterBarrier ensures that both blocks are done using x_smem before it is overwritten. Without the barrier, one of the blocks would be able to run ahead and start overwriting x_smem by entering the collective copy before the other block is done reading from it.
def collective_smem_reuse(x_gmem, x_gmem2, y_gmem, x_smem, local_barrier, cluster_barrier):
  plgpu.copy_gmem_to_smem(x_gmem, x_smem, local_barrier, collective_axes="cluster")
  plgpu.barrier_wait(local_barrier)  # x_smem is ready to be used once the local wait completes
  y_gmem[0] = x_smem[...]
  plgpu.barrier_arrive(cluster_barrier)
  plgpu.barrier_wait(cluster_barrier)  # x_smem can only be reused once the cluster barrier completes
  plgpu.copy_gmem_to_smem(x_gmem2, x_smem, local_barrier, collective_axes="cluster")
  plgpu.barrier_wait(local_barrier)  # x_smem is ready to be used once the local wait completes
  y_gmem[1] = x_smem[...]
Reusing TMEM for collective MMAs on Blackwell#
This example works very similarly to the one before, only this time TMEM is the shared resource. One block issues collective MMAs for both of them, but they both need to safely complete a read from TMEM before it can be reused for another collective MMA.
def collective_tmem_reuse(acc_tmem, lhs_ref, rhs_ref, mma_barrier, cluster_barrier):
  leader_block = lax.axis_index("cluster") == 0

  @pl.when(leader_block)
  def _do_mma():
    plgpu.tcgen05_mma(
        acc_tmem, lhs_ref.at[0], rhs_ref.at[0], mma_barrier,
        accumulate=False, collective_axis="x",
    )

  plgpu.barrier_wait(mma_barrier)
  do_something(plgpu.async_load_tmem(acc_tmem))
  plgpu.wait_load_tmem()  # Ensure the load is complete.
  plgpu.barrier_arrive(cluster_barrier)
  plgpu.barrier_wait(cluster_barrier)  # acc_tmem can only be reused once the cluster barrier completes

  @pl.when(leader_block)
  def _do_mma():
    plgpu.tcgen05_mma(
        acc_tmem, lhs_ref.at[1], rhs_ref.at[1], mma_barrier,
        accumulate=False, collective_axis="x",
    )

  ...
Semaphore#
Semaphores are powerful synchronization structures, primarily used to synchronize across different blocks, potentially running on different devices. For synchronization between threads within a single block, it is preferable to use Barriers, while for cluster synchronization it is preferable to use ClusterBarriers. Semaphores are implemented as 32-bit atomic counters located in GMEM that support the following operations:
pl.semaphore_signal, which atomically increments the semaphore. Any effects performed by the thread before the signal (including reads or writes to remote memory over NVLINK) are guaranteed to complete before the signal is visible on the target device.
pl.semaphore_wait, which blocks the thread until the semaphore reaches at least the desired value, at which point the value is atomically decreased and the thread is awoken. The function can be optionally called with decrement=False, which will wake the thread as soon as the value is at least the requested value, but the value of the semaphore will not be decreased. The non-decrementing version is a bit more efficient.
Here we present a small example kernel that exchanges two small shards between two devices:
# Assumes: from jax.sharding import PartitionSpec as P
def exchange_shards(x_ref, y_ref, done_sem):
  other_dev_id = 1 - lax.axis_index("x")  # We assume two devices
  neighbor_ref = plgpu.remote_ref(y_ref, other_dev_id)
  neighbor_ref[...] = x_ref[...]  # This will write over NVLINK
  pl.semaphore_signal(done_sem, device_id=other_dev_id)  # Signal that the write is complete
  pl.semaphore_wait(done_sem)  # Wait for the other device to write to our memory

mesh = jax.make_mesh((2,), ("x",))
y = jax.jit(
    jax.shard_map(
        lambda x: plgpu.kernel(
            exchange_shards,
            out_shape=x,
            scratch_shapes=[plgpu.SemaphoreType.REGULAR],
        )(x),
        mesh=mesh, in_specs=P("x"), out_specs=P("x"), check_vma=False,
    )
)(x)
Cluster launch control#
Cluster launch control is a feature introduced in Blackwell GPUs (SM100A+) that enables work stealing or dynamic scheduling of the CUDA grid. This allows an SM (or cluster of SMs) that has finished its work to cancel the launch of a block intended for another SM and execute the work for itself. The end result is that load balancing across SMs is improved and you should see better utilization of the GPU towards the tail end of a kernel. Mosaic GPU exposes both the low-level cluster launch control commands as well as a helper API that abstracts away most of the implementation details.
Directly using the cluster launch control API#
Mosaic GPU directly exposes the low-level cluster launch control API as two functions: plgpu.try_cluster_cancel and plgpu.query_cluster_cancel. try_cluster_cancel is an asynchronous operation that will atomically attempt to cancel the launch of an available block, and place the result in a Ref. The result Ref should be a scratch Ref allocated via plgpu.TryClusterCancelResult() (which under the hood is a 16-byte SMEM Ref). query_cluster_cancel will read the result and return two values: a tuple containing the indices of the grid axes that were requested, and a boolean indicating whether the cancellation was successful. If the cancellation was not successful, then the grid indices are undefined and should not be used.
When used with clusters, all blocks within the same cluster will receive the same result from query_cluster_cancel.
The following example demonstrates how to call these with a kernel:
@functools.partial(
    plgpu.kernel,
    grid=grid,
    grid_names=grid_names,
    scratch_shapes=dict(
        result_ref=plgpu.TryClusterCancelResult(),
        barrier_ref=plgpu.Barrier(),
    ),
)
def kernel(result_ref, barrier_ref):
  plgpu.try_cluster_cancel(result_ref, barrier_ref)
  # ... do work
  plgpu.barrier_wait(barrier_ref)
  grid_idxs, success = plgpu.query_cluster_cancel(result_ref, grid_names)
Warning
It is important to ensure proper synchronization on all threads throughout the cluster. In most cases when canceling multiple blocks, you may need to double-buffer the result and barrier to ensure no race conditions occur. For this reason we recommend using the plgpu.dynamic_scheduling_loop helper function.
Using the plgpu.dynamic_scheduling_loop helper#
A common pattern when using dynamic work scheduling is to continuously poll and execute work within the kernel body until there is no more work left, and then exit the kernel. The plgpu.dynamic_scheduling_loop helper function implements exactly this pattern.
@plgpu.dynamic_scheduling_loop(
    grid_names=grid_names,
    thread_axis=thread_name,  # Required if using multiple threads in a kernel.
)
def body(loop_info):
  grid_indices = loop_info.index
  # ... do work
When using this pattern, the kernel should be instantiated with a grid equal to the logical amount of work to be done (as opposed to a persistent kernel where the grid is set to the number of cores). Each core running this loop will continuously query the next available block of work and the loop will terminate when the entire grid has been scheduled. The signature of the body function is identical to the one used in plgpu.nd_loop (which is used for normal persistent kernels) and takes in a loop_info dataclass that contains iteration info, and optionally supports carry values.
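The scheduling pattern can be mimicked on a CPU with Python threads. This is a rough analogy, not the hardware mechanism: `try_cancel` is a hypothetical stand-in for the try/query pair, a lock-protected counter plays the role of the pending part of the grid, and the sketch assumes the first `num_workers` blocks are launched directly while the rest are claimed dynamically.

```python
import threading

grid_size = 20    # Logical amount of work: one block per work item.
num_workers = 4   # Stand-in for the number of cores running concurrently.

lock = threading.Lock()
next_block = [num_workers]  # Blocks 0..num_workers-1 are launched directly.
executed = [[] for _ in range(num_workers)]


def try_cancel():
    """Toy stand-in for try/query cancel: returns (block_index, success)."""
    with lock:
        if next_block[0] >= grid_size:
            return None, False  # The index is meaningless on failure.
        idx = next_block[0]
        next_block[0] += 1
        return idx, True


def worker(worker_id):
    block, has_work = worker_id, True  # The block this worker was launched as.
    while has_work:
        executed[worker_id].append(block)  # ... do work for `block`
        block, has_work = try_cancel()


threads = [threading.Thread(target=worker, args=(w,)) for w in range(num_workers)]
for t in threads:
    t.start()
for t in threads:
    t.join()

all_blocks = sorted(b for per_worker in executed for b in per_worker)
```

Which worker ends up executing which block varies from run to run, but every block of the logical grid is executed exactly once.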
Asynchronous copies#
Modern GPUs can directly and asynchronously copy data between GMEM and SMEM without involving registers. Starting from the Hopper generation, the copies can even be offloaded to a special hardware unit called the Tensor Memory Accelerator (TMA), which is what Mosaic uses to implement them.
GMEM to SMEM copies#
To schedule an asynchronous GMEM to SMEM copy, use plgpu.copy_gmem_to_smem. The function takes three operands: a source ref, a destination ref and a Barrier. Once the copy is complete, a single arrival will be observed on the barrier, as if plgpu.barrier_arrive(barrier) was called by a background thread:
def body(in_gmem_ref, out_gmem_ref, smem_ref, barrier):
  plgpu.copy_gmem_to_smem(in_gmem_ref, smem_ref, barrier)
  plgpu.barrier_wait(barrier)
  ...

plgpu.kernel(
    body,
    out_shape=...,
    scratch_shapes=[plgpu.SMEM(x.shape, x.dtype), plgpu.Barrier()],
)
A single barrier can be used to synchronize multiple copies, but it has to be allocated with a higher num_arrivals:
def body(in_gmem_ref, in_gmem_ref2, out_gmem_ref, smem_ref, smem_ref2, barrier):
  plgpu.copy_gmem_to_smem(in_gmem_ref, smem_ref, barrier)
  plgpu.copy_gmem_to_smem(in_gmem_ref2, smem_ref2, barrier)
  plgpu.barrier_wait(barrier)  # Awaits both copies
  ...

plgpu.kernel(
    body,
    out_shape=...,
    # One SMEM buffer per copy. Barrier is allocated with 2 arrivals.
    scratch_shapes=[
        plgpu.SMEM(x.shape, x.dtype),
        plgpu.SMEM(x.shape, x.dtype),
        plgpu.Barrier(num_arrivals=2),
    ],
)
Collective copies#
When using block clusters, the asynchronous transfers feature a multicast option, meaning that multiple blocks from the cluster can collectively load the same input. In some sense, this can be seen as a guaranteed L2 hit for all participating blocks, as it allows for better sharing of the limited HBM bandwidth.
Warning
When using collective copies, all blocks along the specified cluster axes must issue the same collective copy for the program to be valid. It is not allowed to only issue it from one block but not from others and it will result in undefined behavior (most likely a deadlock).
Warning
When using collective copies, you need to be extra careful about reusing the SMEM buffers. The different blocks in the cluster might finish using them at different points in time but the first block that issues the next collective copy can overwrite the data still used by other blocks. See the ClusterBarrier section for examples for how to make this safe.
def body(in_gmem_ref, in_gmem_ref2, out_gmem_ref, smem_ref, smem_ref2, barrier):
  block_id = lax.axis_index("cluster")
  # Both blocks in the cluster load the same data into smem_ref, so we can use
  # a collective copy here.
  plgpu.copy_gmem_to_smem(in_gmem_ref, smem_ref, barrier, collective_axes="cluster")
  # Each block in the cluster loads a different slice of in_gmem_ref2, so we
  # are not allowed to use collective copies.
  plgpu.copy_gmem_to_smem(in_gmem_ref2.at[block_id], smem_ref2, barrier)
  plgpu.barrier_wait(barrier)  # Awaits both copies
  ...

plgpu.kernel(
    body,
    out_shape=...,
    # One SMEM buffer per copy. Barrier is allocated with 2 arrivals.
    scratch_shapes=[
        plgpu.SMEM(x.shape, x.dtype),
        plgpu.SMEM(x.shape, x.dtype),
        plgpu.Barrier(num_arrivals=2),
    ],
)
Collective partitioned copies (Blackwell only)#
On Blackwell-generation GPUs, collective copies that involve clusters of two blocks can be partitioned by passing an additional partitioned_axis argument. When specified, the GMEM reference is expected to be double the size of the destination SMEM reference along the specified dimension. The destination in the first block will be overwritten with the first half of the GMEM ref, while the second block will receive the second half.
This by itself would be equivalent to performing two non-collective copies on different input slices, but there’s one crucial difference: only the barrier in the first block will receive the arrival once both copies complete. The barrier argument in the second block is ignored and the second block cannot use it to await the completion of the transfer.
Arguably, this is a bit of a surprising feature, but it makes sense in the context of collective MMAs on Blackwell. There, each block is responsible for loading the operands into SMEM, but only the first block awaits the completion of the transfers and issues the MMA instructions. The second block usually waits on the completion of the MMA to indicate that the transfer is done, and the SMEM data has been read out, implying that it can safely overwrite it.
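The data movement of a partitioned copy can be sketched in plain Python for the one-dimensional case. `toy_partitioned_copy` is a hypothetical helper written for illustration, not a Pallas function: lists stand in for the GMEM and SMEM references, and the returned arrival counts model which block's barrier observes the completion.

```python
def toy_partitioned_copy(gmem, num_blocks=2):
    """Returns (per-block SMEM contents, per-block barrier arrival counts)."""
    assert len(gmem) % num_blocks == 0
    part = len(gmem) // num_blocks
    # Block b receives the b-th contiguous part of the source.
    smem = [gmem[b * part:(b + 1) * part] for b in range(num_blocks)]
    # Only the first block's barrier observes the completion.
    arrivals = [1] + [0] * (num_blocks - 1)
    return smem, arrivals


gmem = list(range(8))
smem, arrivals = toy_partitioned_copy(gmem)
```

In this model the second block has its half of the data but no arrival to wait on, which is why it has to learn about completion some other way, such as the MMA completion mentioned above.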
SMEM to GMEM copies#
To schedule an asynchronous SMEM to GMEM copy, use plgpu.copy_smem_to_gmem. As opposed to the other direction, this primitive only takes in the source and destination references. To await the completion of the copy, use plgpu.wait_smem_to_gmem.
The synchronization scheme for SMEM to GMEM copies is a little unexpected in that they cannot be awaited in arbitrary orders. plgpu.wait_smem_to_gmem takes as an argument the number of most recent copies you do not want to await, or equivalently the number of asynchronous SMEM to GMEM copies that you still want to allow to run:
def copy_out(x_smem, y_smem, x_gmem, y_gmem):
  plgpu.copy_smem_to_gmem(x_smem, x_gmem)
  plgpu.copy_smem_to_gmem(y_smem, y_gmem)
  plgpu.wait_smem_to_gmem(1, wait_read_only=True)
  # At this point we know that the data of x_smem has been read, but we don't
  # yet know that x_gmem contains the updated data.
  plgpu.wait_smem_to_gmem(1)
  # At this point we know that the x_smem -> x_gmem copy is done, but we know
  # nothing about the y_smem -> y_gmem copy.
  plgpu.wait_smem_to_gmem(0)
  # At this point we know that both copies are complete.
Note that an SMEM to GMEM copy can only ever be awaited in the same thread that has issued it. wait_smem_to_gmem returns immediately if no copies have been issued or they have all completed.
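The "number of copies still allowed to run" rule can be captured by a small bookkeeping model. `ToyCopyQueue` is a hypothetical class written for illustration. It assumes copies retire strictly in issue order and ignores the wait_read_only distinction.

```python
class ToyCopyQueue:
    """Toy model of in-order SMEM-to-GMEM copy completion tracking."""

    def __init__(self):
        self.in_flight = []  # Oldest copy first.
        self.done = []

    def copy(self, name):
        self.in_flight.append(name)

    def wait(self, allowed_in_flight):
        # Retire the oldest copies until at most `allowed_in_flight` remain.
        while len(self.in_flight) > allowed_in_flight:
            self.done.append(self.in_flight.pop(0))


q = ToyCopyQueue()
q.wait(0)  # No copies issued: nothing to retire, returns immediately.
q.copy("x_smem -> x_gmem")
q.copy("y_smem -> y_gmem")
q.wait(1)
done_after_wait_1 = list(q.done)
q.wait(0)
done_after_wait_0 = list(q.done)
```

After the wait with argument 1 only the older copy is known to be complete. The wait with argument 0 is what guarantees completion of both.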
Only awaiting the read from SMEM#
When awaiting an SMEM to GMEM copy, you can either wait until the copy is fully written into GMEM (in a way that will be visible to following reads), or only await the data being read from SMEM by specifying wait_read_only in the wait function. This allows for a faster reuse of SMEM buffers if you don’t intend to read back the data sent to GMEM just yet.
Grouping multiple copies#
When copy_smem_to_gmem receives commit_group=False as an argument, it cannot be awaited until plgpu.commit_group is called explicitly, or another copy_smem_to_gmem without that argument is issued. All SMEM to GMEM copies since the last commit are grouped together as a single awaitable unit:
def copy_out(x_smem, y_smem, x_gmem, y_gmem):
  plgpu.copy_smem_to_gmem(x_smem, x_gmem, commit_group=False)
  plgpu.copy_smem_to_gmem(y_smem, y_gmem)  # Implicitly commits both copies
  plgpu.wait_smem_to_gmem(1)
  # At this point we only know that no SMEM to GMEM copies other than the two
  # above are active.
  plgpu.wait_smem_to_gmem(0)
  # Only now we know that both copies above have completed.
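The effect of grouping on the wait count can be sketched with a toy model in which waits count committed groups rather than individual copies. `ToyGroupedCopyQueue` is a hypothetical class, and in-order retirement of groups is an assumption of the model.

```python
class ToyGroupedCopyQueue:
    """Toy model: waits count committed groups, not individual copies."""

    def __init__(self):
        self.open_group = []  # Copies issued but not yet committed.
        self.committed = []   # Committed groups, oldest first.
        self.done = []

    def copy(self, name, commit_group=True):
        self.open_group.append(name)
        if commit_group:
            self.commit()

    def commit(self):
        if self.open_group:
            self.committed.append(self.open_group)
            self.open_group = []

    def wait(self, allowed_groups_in_flight):
        # Retire the oldest groups until at most the allowed number remain.
        while len(self.committed) > allowed_groups_in_flight:
            self.done.extend(self.committed.pop(0))


q = ToyGroupedCopyQueue()
q.copy("x_smem -> x_gmem", commit_group=False)
q.copy("y_smem -> y_gmem")  # Commits both copies as one group.
groups_in_flight = len(q.committed)
q.wait(1)
done_after_wait_1 = list(q.done)
q.wait(0)
done_after_wait_0 = list(q.done)
```

Since both copies form a single group, the wait with argument 1 retires nothing in this model and only the wait with argument 0 completes them.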
Asynchronous gathers#
On Blackwell GPUs, the TMA engine has an additional mode that allows for an efficient implementation of gathers along the first dimension on a 2D matrix. Using this mode is actually very simple. The 1D array of indices should be loaded into a plgpu.Layout.TMA_GATHER_INDICES layout, and the source GMEM reference has to be indexed with that array using the .at operator:
@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct(out_shape, dtype),
    out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM, transforms=transforms),
    in_specs=(
        pl.BlockSpec(memory_space=plgpu.GMEM),
        pl.BlockSpec(memory_space=plgpu.SMEM),
    ),
    scratch_shapes=[plgpu.Barrier()],
)
def kernel(x_ref_gmem, idx_ref, o_ref, barrier_ref):
  idxs = plgpu.load(idx_ref, (), layout=plgpu.Layout.TMA_GATHER_INDICES)
  plgpu.copy_gmem_to_smem(x_ref_gmem.at[idxs], o_ref, barrier_ref)
  plgpu.barrier_wait(barrier_ref)
plgpu.copy_gmem_to_smem automatically recognizes that the reference has been sliced with an array and will use the gather TMA instructions to implement the copy.
NVLINK transfers#
Asynchronous copies in either direction support GMEM references returned from plgpu.remote_ref, which makes it possible to perform NVLINK transfers asynchronously.
# Assumes: from jax.sharding import PartitionSpec as P
def exchange_shards(x_ref, y_ref, smem_ref, local_barrier, done_sem):
  plgpu.copy_gmem_to_smem(x_ref, smem_ref, local_barrier)  # Local copy
  plgpu.barrier_wait(local_barrier)
  other_dev_id = 1 - lax.axis_index("x")  # We assume two devices
  neighbor_ref = plgpu.remote_ref(y_ref, other_dev_id)
  plgpu.copy_smem_to_gmem(smem_ref, neighbor_ref)
  plgpu.wait_smem_to_gmem(0)  # Wait for the asynchronous write to complete
  pl.semaphore_signal(done_sem, device_id=other_dev_id)  # Signal that the write is complete
  pl.semaphore_wait(done_sem)  # Wait for the other device to write to our memory

mesh = jax.make_mesh((2,), ("x",))
y = jax.jit(
    jax.shard_map(
        lambda x: plgpu.kernel(
            exchange_shards,
            out_shape=x,
            scratch_shapes=[
                plgpu.SMEM(x.shape, x.dtype),
                plgpu.Barrier(),
                plgpu.SemaphoreType.REGULAR,
            ],
        )(x),
        mesh=mesh, in_specs=P("x"), out_specs=P("x"), check_vma=False,
    )
)(x)
Inline Mosaic GPU#
TODO
Compiler parameters#
TODO
Debugging#
Mosaic GPU exposes a number of environment variables to diagnose issues with the generated low-level code:
MOSAIC_GPU_DUMP_PTXAS allows dumping the compilation logs from ptxas to standard output when set;
MOSAIC_GPU_DUMP_PTX allows dumping the PTX code generated during compilation to standard output when set;
MOSAIC_GPU_DUMP_MLIR_PASSES allows dumping the IR after every MLIR pass in the compilation pipeline to standard output;
MOSAIC_GPU_DUMP_SASS allows dumping the SASS code produced at the end of compilation to standard output;
MOSAIC_GPU_DUMP_SASS_CTRL allows dumping the SASS control codes following NervanaSystems/maxas to standard output;
MOSAIC_GPU_DUMP_TO allows specifying a directory path (that must exist) where all of the above will be dumped as files.
MOSAIC_GPU_LLVM_DEBUG_ONLY allows specifying a comma-separated list of LLVM debug types, in order to produce relevant LLVM debugging logs. This environment variable is only available in debug builds (i.e. builds without NDEBUG).
MOSAIC_GPU_DUMP_LLVM allows dumping LLVM IR when set. It is equivalent to setting MOSAIC_GPU_LLVM_DEBUG_ONLY=serialize-to-llvm, and both environment variables compose. Like MOSAIC_GPU_LLVM_DEBUG_ONLY, this environment variable is only available in debug builds.
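One way to enable these dumps from a Python script is to set the variables before the kernel is compiled. The sketch below rests on two assumptions that the text above does not spell out: that any non-empty value such as "1" counts as "set", and that setting the variables before importing JAX is early enough. The temporary directory is only an illustration of the requirement that the MOSAIC_GPU_DUMP_TO path exists.

```python
import os
import tempfile

# The dump directory must already exist, so create it up front.
dump_dir = tempfile.mkdtemp(prefix="mosaic_gpu_dump_")

# Assumption: "1" is treated as "set"; the exact accepted values may differ.
os.environ["MOSAIC_GPU_DUMP_PTX"] = "1"
os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
os.environ["MOSAIC_GPU_DUMP_TO"] = dump_dir

# Import JAX and build the kernel only after this point.
```

The same variables can of course be exported in the shell that launches the program instead.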
Calling kernels from PyTorch#
The plgpu.as_torch_kernel decorator wraps a Pallas:MGPU kernel to allow invoking it with PyTorch tensors. It accepts CUDA tensors as inputs and returns newly allocated CUDA tensors on the same device.
Example:
import functools
import jax
import jax.numpy as jnp
import torch

@functools.partial(
    pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32)
)
def add_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

x = torch.arange(128, dtype=torch.int32, device="cuda")
y = x * x
out = plgpu.as_torch_kernel(add_kernel)(x, y)
plgpu.as_torch_kernel only supports functions that contain a single kernel invocation (e.g. via pl.pallas_call or plgpu.kernel), and no calls to other JAX operations, e.g. from jax.numpy.
