# Collective matrix multiplication
Tensor parallelism (TP) and data parallelism (DP) are the most frequently used parallelism techniques that make it possible to fit ever-larger models onto a number of accelerators. However, their joint use means that in our programs, we sometimes end up with data sharded in ways that don’t make it directly possible to execute an operation without additional communication. One such problem frequently happens at the beginning of the MLP block of a Transformer. There, the input activations might be sharded on the batch axis (DP), while the weights might be partitioned on the output feature dimension (TP).
The contraction dimension is not sharded, so it might seem that we can just multiply the inputs, but there is a problem: the output can’t be sharded along the same device axis on both of its dimensions!
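To make the conflict concrete, here is a minimal sketch of the shardings involved (the shapes and the 1D mesh axis named `x` are chosen purely for illustration):

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ("x",))
# DP: activations sharded along the batch (M) dimension.
x = jax.device_put(jnp.zeros((1024, 4096), jnp.bfloat16),
                   NamedSharding(mesh, P("x", None)))
# TP: weights sharded along the output feature (N) dimension.
w = jax.device_put(jnp.zeros((4096, 4096), jnp.bfloat16),
                   NamedSharding(mesh, P(None, "x")))
# The contraction (K) dimension is replicated on both operands, but a plain
# x @ w would need its output sharded as P("x", "x"): both output dimensions
# over the same mesh axis, which is impossible without extra communication.
```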
There’s a simple way to solve this problem: we can all-gather activations or weights (here we focus on the activation side), and then perform a local matrix multiplication with the other operand sharded. This simple strategy works, but it has a downside: we can’t begin computing the matrix multiplication while the all-gather is running! That means we’re underutilizing our hardware!
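This baseline fits in a few lines of JAX; a sketch under `jax.shard_map`, reusing `mesh`, `x`, and `w` from the snippet above:

```python
import functools

@functools.partial(
    jax.shard_map, mesh=mesh,
    in_specs=(P("x", None), P(None, "x")), out_specs=P(None, "x"),
)
def baseline(x_shard, w_shard):
  # Gather the full activations along the batch axis, then multiply by the
  # local weight shard. The matmul cannot start until the all-gather is done.
  x_full = jax.lax.all_gather(x_shard, "x", axis=0, tiled=True)
  return x_full @ w_shard

y = baseline(x, w)  # sharded as P(None, "x")
```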
To achieve better utilization, we’ll show how simple it is to implement a Pallas:MGPU kernel that overlaps the cross-device communication with the matrix multiplication, achieving almost optimal utilization on large enough problem shapes. Our implementation makes heavy use of the NVLINK interconnect, which allows us to perform high-bandwidth inter-GPU communication without involving the host.
This approach already yields considerable performance improvements! If we consider an f16 matmul with M=1024, K=4096 and N=4096 and normally distributed data, our benchmarks indicate that it should take about 43us on a single H100. In the table below, we scale up the M dimension so that the per-shard shape stays at M=1024. We can compute an expected lower bound for the execution of our distributed kernel by multiplying that local runtime estimate by the number of devices and adding about 6us for each round of communication (the memory fences associated with the synchronization are expensive). Benchmarking our kernel yields the following results:
| Device count | Kernel time | TC utilization | Lower bound | TC utilization | Reference time | TC utilization |
|---|---|---|---|---|---|---|
| 2 | 102us | 68% | 92us | 75% | 147us | 47% |
| 4 | 212us | 66% | 190us | 73% | 290us | 48% |
| 8 | 436us | 64% | 386us | 72% | 565us | 49% |
As you can see, there are still some opportunities for optimization here, but at least we’re getting much better utilization compared to the baseline implementation of an NCCL all-gather followed by a cuBLAS matmul.
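For reference, the lower-bound column above is just the back-of-the-envelope arithmetic described earlier: the single-device estimate scaled by the device count, plus one ~6us fence per round of communication (there are D - 1 such rounds):

```python
# Reconstructing the "Lower bound" column above.
local_matmul_us = 43  # one M=1024, K=4096, N=4096 f16 matmul on a single H100
fence_us = 6          # approximate cost of the fences for one communication round
for d in (2, 4, 8):
  print(d, d * local_matmul_us + (d - 1) * fence_us)  # -> 92, 190, 386
```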
## Algorithm overview: Ring All-Gather
To compute `AllGather(A) @ B`, we form a ring on the participating `D` devices. At each step, the device takes the last received shard (starting from its local shard), and passes it to the next device in the ring. While the send is happening, we compute the matrix multiplication between the last received `A` shard and the local `B` shard.
More formally, the algorithm proceeds in `D` steps. In step `i` (`0 <= i < D`), device `d` receives shard `A_{(d+i)%D}` (we don’t actually receive in the first step) from device `(d+1)%D`, computes `A_{(d+i)%D} @ B_d`, and writes the result to a slice of the output buffer. Concurrently with the compute, device `d` sends shard `A_{(d+i)%D}` to device `(d-1)%D` for its use in step `i+1` (we don’t send in the last step). After `D` steps, device `d` will have seen every shard of `A` and computed the full output.
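The index arithmetic is easy to get wrong, so here is a tiny host-side simulation of the schedule (purely illustrative, not part of the kernel) that prints what each device computes and sends at every step:

```python
D = 4  # example device count
for i in range(D):
  for d in range(D):
    shard = (d + i) % D
    line = f"step {i}: device {d} computes A_{shard} @ B_{d}"
    if i != D - 1:  # no send on the last step
      line += f", sends A_{shard} to device {(d - 1) % D}"
    print(line)
```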
## Pallas primitives for inter-device communication
We use three Pallas functions for inter-device communication:
- `plgpu.remote_ref(ref, device_id)`: This function takes a reference to a buffer in global memory (GMEM) and returns a reference to the same buffer on a different device, specified by `device_id`. When communicating over NVLINK, this reference can be read or written to directly, even though its data is located in remote memory.
- `pl.semaphore_signal(sem, device_id=...)`: Increments a semaphore on a target device. This is usually used to indicate completion of some process, such as when we notify the remote device that the data it’s waiting for has been sent.
- `pl.semaphore_wait(sem, value=..., decrement=...)`: Blocks until a local semaphore reaches a certain value. If `decrement` is `True` (default), the value of the semaphore is decreased by the awaited amount. If it is `False`, the operation is more efficient, but it does not modify the value of the semaphore after the wait completes. This is frequently used to await signals from a remote device.
## Implementation with Pallas
Note
Here, we only present a simplified version of the kernel, which allows us to focus on the most interesting details. You can find the full implementation in our examples directory.
First, we focus on the set-up of our kernel. For the compute part, we will reuse our optimized matmul kernel implementation from `hopper_matmul_mgpu`. Since the compute kernel utilizes warp specialization, we use 3 Pallas threads. It is also persistent, which means that we launch a grid as large as the number of SMs (queried from `.core_count` on the JAX device). The compute kernel uses `pl.run_scoped` for SMEM allocations, so we don’t use `scratch_shapes`.
```python
def all_gather_lhs_matmul(
    lhs: jax.Array,
    rhs: jax.Array,
    axis_name,
    *,
    config: hopper_matmul_mgpu.TuningConfig,
    dtype: jnp.dtype = jnp.bfloat16,
) -> jax.Array:
  if (num_devices := jax.device_count()) != jax.process_count():
    raise ValueError("The kernel only supports one device per process")
  if (axis_size := lax.axis_size(axis_name)) != num_devices:
    raise ValueError("The kernel can only work over all devices in a Mesh.")
  ...
  m_shard, k = lhs.shape
  _, n_shard = rhs.shape
  tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k
  cta_tile_m = tile_m * (1 + (config.wg_dimension == MatmulDimension.M))
  num_sms = jax.extend.backend.get_default_device().core_count

  def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref):
    ...

  result, _ = plgpu.kernel(
      kernel_body,
      out_shape=[
          # The output (with M gathered)
          jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), dtype),
          # A scratch buffer for LHS all-gather
          jax.ShapeDtypeStruct((axis_size - 1, m_shard, k), dtype),
      ],
      grid=(num_sms,),
      num_threads=3,  # The matmul kernel uses 3 threads: 2 compute and 1 memory
      thread_name="wg",
  )(lhs, rhs)
  return result
```
The kernel above has two outputs. The first one is the actual result of our primitive, while the second one is used as scratch space to receive the left operands. Note that we could shrink the leading axis to be smaller than `axis_size - 1`, but at that point we would need to introduce backpressure on the sending devices, which requires additional expensive communication.
Note
You can see how to deal with this backpressure in the TPU distributed communication guide.
Let us now look at the outline of the kernel body:
```python
def all_gather_lhs_matmul(...):
  def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref, out_smem, received_sem):
    wg_idx = lax.axis_index("wg")
    dev_id = lax.axis_index(axis_name)
    # This device sends to dev_id - 1, forming a ring.
    send_dev_id = lax.rem(dev_id + axis_size - 1, axis_size)
    send_scratch_ref = plgpu.remote_ref(scratch_ref, send_dev_id)

    def device_step(lhs_source_ref, device_offset):
      # Invariant: lhs_source_ref contains A_{(dev_id + device_offset) % D}
      # and is ready to be used for computation.
      ...

    # We peel the first step to read data directly from lhs_local_ref.
    device_step(lhs_local_ref, 0)

    @pl.loop(1, num_devices)
    def _device_loop(device_offset):
      device_step(scratch_ref.at[device_offset - 1], device_offset)
```
We locate our position in the ring by querying `lax.axis_index(axis_name)` and compute the index of the next device, to which we will be sending the data (`send_dev_id`). Then, we loop over invocations of `device_step` as many times as there are devices. We peel the first step of the loop, because we use the local reference as the source for the send in that step only (after that, the sends originate from the data previously received into the scratch buffer).
We are ready to investigate the main loop now:
```python
def all_gather_lhs_matmul(...):
  ...
  def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref, out_smem, received_sem):
    ...
    def device_step(lhs_source_ref, device_offset):
      # We are computing block (dev_id + device_offset) % D of the output.
      out_device_idx = lax.rem(device_offset + dev_id, axis_size)
      out_device_m_slice = pl.ds(out_device_idx * m_shard, m_shard)

      # In step `device_offset`, we send A_{(dev_id + device_offset) % D} to
      # the next device in the ring, into scratch slot `device_offset`.
      # We also don't send on the last step since that would return the data
      # back to its original source.
      next_scratch_slot = device_offset
      is_send_wg = wg_idx == 0  # Only one warpgroup per CTA sends
      has_send_space = next_scratch_slot < axis_size - 1
      should_send = is_send_wg & has_send_space

      # This function will be called by hopper_matmul_mgpu.kernel in the body
      # of its pipeline. We use it to take the tile of LHS loaded into SMEM and
      # issue a TMA send to the next device in the ring.
      def send_lhs(m_idx, n_idx, k_idx, a_smem, b_smem, send_ref, should_send):
        del b_smem  # Unused.
        # We only send when n_idx == 0 to avoid sending the same data
        # multiple times when revisiting the left operand.
        @pl.when(should_send & jnp.bool(n_idx == 0))
        def _():
          k_slice = pl.ds(k_idx * tile_k, tile_k)
          m_slice = pl.ds(m_idx * cta_tile_m, cta_tile_m)
          plgpu.copy_smem_to_gmem(a_smem, send_ref.at[m_slice, k_slice])
          # Wait for previous copies to complete. We pass in delay_release=1
          # to the pipeline in the matmul kernel to ensure that it doesn't
          # overwrite the input until at least the next step completes, but it
          # will not wait any longer.
          plgpu.wait_smem_to_gmem(1, wait_read_only=True)

      hopper_matmul_mgpu.kernel(
          lhs_source_ref,  # LHS shard for this step
          rhs_ref,  # RHS shard is always the same
          out_ref.at[out_device_m_slice],  # Slice of output to update
          out_smem,
          config=config,
          pipeline_callback=functools.partial(
              send_lhs,
              send_ref=send_scratch_ref.at[next_scratch_slot],
              should_send=should_send,
          ),
          delay_release=1,
      )

      # Wait for the next scratch to arrive for the next step's computation.
      # Each device signals its neighbor when it has finished sending.
      @pl.when(should_send)
      def _signal():
        # Make sure our remote copy is done, then signal.
        plgpu.wait_smem_to_gmem(0, wait_read_only=False)
        pl.semaphore_signal(received_sem, device_id=send_dev_id)

      @pl.when(has_send_space)
      def _wait():
        # Here, we wait for the data to arrive from the previous device in the
        # ring. At each step, we expect to receive a signal from each SM.
        # We use decrement=False to make this operation slightly faster, but
        # this also means that we need to scale the expected number of signals
        # by the number of steps taken so far (as the value only increases).
        pl.semaphore_wait(
            received_sem,
            value=(device_offset + 1) * num_sms,
            decrement=False,
        )
    ...
```
A few things happen here in a sequence:
1. We begin by computing the slice of the output that we will compute at this step of the loop.
2. Then, we call into the optimized matmul kernel, but inject it with a `pipeline_callback`. We use it to take advantage of the fact that the compute kernel has to fetch the left operand into SMEM, and we instruct the TMA engine to asynchronously stream the local data to the next device. The traffic is transparently routed through NVLINK by the hardware. It is worth noting that we only issue sends from one of the compute threads and only when we visit the left operand for the first time (it might be reloaded many times to compute many output tiles).
3. Finally, the sending thread makes sure that the sends have completed and signals the `received_sem` on the receiving device to indicate that. After that, all threads wait until they are sure that all the data for the next step of the loop has been received (the wait is skipped on the last step).
## Integrating the kernel with JAX
To invoke the kernel, you need to wrap it into `jax.shard_map`:
```python
m_shard, n_shard, k = 1024, 1024, 1024
dtype = jnp.float16
mesh = jax.make_mesh(
    (jax.device_count(),), ("x",), axis_types=(jax.sharding.AxisType.Explicit,)
)
with jax.set_mesh(mesh):
  a = jax.random.normal(jax.random.key(1), (m_shard * jax.device_count(), k), dtype)
  b = jax.random.normal(jax.random.key(2), (k, n_shard * jax.device_count()), dtype)
  a = jax.sharding.reshard(a, P("x", None))
  b = jax.sharding.reshard(b, P(None, "x"))
  # Example config for 8xH100. You might need to retune to your shape.
  config = hopper_matmul_mgpu.TuningConfig(
      tile_m=128,
      tile_n=128,
      tile_k=64,
      max_concurrent_steps=4,
      grid_minor_dim=MatmulDimension.N,
      grid_tile_width=8,
      wg_dimension=MatmulDimension.N,
  )
  kernel = jax.jit(
      jax.shard_map(
          functools.partial(all_gather_lhs_matmul, axis_name="x", config=config),
          out_specs=P(None, "x"),
          check_vma=False,
      )
  )
  c = kernel(a, b)
```
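As an optional sanity check (not part of the example above), one could compare the kernel output against a host-side NumPy reference. This assumes the single-process setup from the previous snippet and is slow at these shapes, but unambiguous:

```python
import numpy as np

# Pull the sharded arrays back to the host and compare against a float32
# NumPy matmul. The tolerances are loose to account for the f16 inputs and outputs.
ref = np.asarray(jax.device_get(a), np.float32) @ np.asarray(jax.device_get(b), np.float32)
np.testing.assert_allclose(
    np.asarray(jax.device_get(c), np.float32), ref, atol=1e-1, rtol=1e-2
)
```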
