# Pallas Async Operations
## Background + Motivation
We'd like to expose APIs in Pallas to explicitly overlap computation and communication across multiple kernels.
### XLA Async Decomposition
As motivation, consider the following JAX pseudocode:
```python
def f(x):
  y = ppermute(x)
  z = x + 1
  return y, z
```
In this function, we could perform the `ppermute` at the same time as the `x + 1`. This is an optimization XLA does automatically by:

1. decomposing the `ppermute` into a `ppermute_start` and a `ppermute_done` op, which are connected via a future.
2. scheduling the `x + 1` between the `ppermute_start` and the `ppermute_done`,
resulting in the following program:
```python
def f(x):
  fut = ppermute_start(x)
  z = x + 1  # happens at the same time as ppermute
  y = ppermute_done(fut)
  return y, z
```
### Async ops inside kernels
Now imagine we aren't using XLA's `ppermute` but have our own custom Pallas `ppermute`.
```python
def ppermute_kernel(x_ref, y_ref, send_sem, recv_sem):
  right_neighbor = ...
  descriptor = pltpu.make_async_remote_copy(
      x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor
  )
  descriptor.start()
  descriptor.wait_send()
  descriptor.wait_recv()

def ppermute(x):
  return pl.pallas_call(ppermute_kernel, out_shape=x, ...)(x)
```
Currently, we cannot decompose `ppermute` into a `start`/`done` pair as XLA does, so instead we explicitly *fuse* the `x + 1` into the kernel.
```python
def add_one(x_ref, z_ref):
  z_ref[...] = x_ref[...] + 1

def ppermute_add_one_kernel(x_ref, y_ref, z_ref, send_sem, recv_sem):
  right_neighbor = ...
  descriptor = pltpu.make_async_remote_copy(
      x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor
  )
  descriptor.start()

  # Explicitly schedule inner kernel between start/wait
  pltpu.emit_pipeline(add_one)(x_ref, z_ref)

  descriptor.wait_send()
  descriptor.wait_recv()

def ppermute_and_add_one(x):
  return pl.pallas_call(ppermute_add_one_kernel, out_shape=(x, x), ...)(x)
```
The goal is to enable writing separate kernels for starting the `ppermute` and waiting on it to complete, so that we can use a regular old `x + 1` in between (or whatever compute we want). This makes the code more readable, maintainable, and less bug-prone.
## How do we implement decomposed Pallas async operations (on TPU)?
The main thing to figure out when implementing decomposed async operations in Pallas is what the *future* that is passed between them contains. Specifically, it must contain some important state about the operation happening in the background.
If we look at the Pallas code, we can see that we need a "descriptor" to both start and wait on a remote copy. Can we plumb this descriptor out of the Pallas kernel, and then pass it into another one? Well, kind of. The underlying TPU hardware tracks async op progress via a pair of semaphores: `send_sem` allows us to wait until a device is done sending data to its neighbor, and `recv_sem` tracks the data transfer sent to a device from its neighbor. If we imagine writing a start kernel and a done kernel, all we'd need to pass from the start to the done would be the semaphores and some information about how much to wait on those semaphores.
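Concretely, we can picture this future as a small tuple. This is just a convention for illustration (not a dedicated Pallas type), using the same `Semaphore`/`Array` annotations as the signatures below:

```python
# The "future": the two DMA semaphores plus the output array, whose
# shape/dtype tells the done kernel how much data to wait for.
Future = tuple[Semaphore, Semaphore, Array]  # (send_sem, recv_sem, out)
```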
We can do this by extending Pallas to support returning semaphores from kernels.
```python
def ppermute_start_kernel(
    in_ref,
    send_sem,
    recv_sem,
    out_ref,
    *,
    axis_name,
):
  axis_size = jax.lax.psum(1, axis_name)
  left_neighbor = jax.lax.rem(
      jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
  )
  right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
  barrier_sem = pltpu.get_barrier_semaphore()
  pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
  pltpu.semaphore_wait(barrier_sem, 1)
  pltpu.make_async_remote_copy(
      in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
  ).start()


def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array]:
  send_sem, recv_sem, out = pl.pallas_call(
      functools.partial(ppermute_start_kernel, axis_name=axis_name),
      out_shape=(
          pltpu.SemaphoreType.DMA(()),
          pltpu.SemaphoreType.DMA(()),
          jax.ShapeDtypeStruct(x.shape, dtype=x.dtype),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
      ],
      out_specs=(
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.ANY),
      ),
  )(x)
  return send_sem, recv_sem, out
```
Note that something subtle is happening here. Pallas is telling XLA that it would like some outputs to be semaphores (a.k.a. sync flags) and XLA will treat them as “reserved” (e.g. while they are alive in the XLA program, those sync flags cannot be allocated by other kernels). They behave similarly to barrier semaphores, which are reserved semaphores managed by XLA.
Another thing to notice is that we return the output buffer `out` from the start kernel *while it's being actively copied into*.
Now we write the `done` kernel that performs the blocking operation. We pass `out` into the kernel to compute the shape needed to block on the semaphore.
```python
def ppermute_done_kernel(ref, send_sem, recv_sem, _):
  pltpu.make_async_copy(ref, ref, send_sem).wait()
  pltpu.make_async_copy(ref, ref, recv_sem).wait()


def ppermute_done(send_sem, recv_sem, out) -> Array:
  out = pl.pallas_call(
      ppermute_done_kernel,
      out_shape=(
          jax.ShapeDtypeStruct(out.shape, dtype=out.dtype),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
      ],
      out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
      input_output_aliases={0: 0},
  )(out, send_sem, recv_sem)
  return out
```
Note: we i/o alias the output buffer here to guarantee that the consumers are downstream of the `ppermute_done`.
We can now implement the decomposed collective permute.
```python
def f(x):
  fut = ppermute_start(x)
  z = x + 1  # happens at the same time as ppermute
  y = ppermute_done(fut)
  return y, z
```
OR CAN WE?
## Why *doesn't* this work?
There are three remaining issues with this, each of which exists outside of Pallas to some degree. Here they are at a high level.
1. **Scheduling** - just because we write `ppermute_start`, then `x + 1`, then `ppermute_done` doesn't guarantee that they will happen in that order. XLA is responsible for scheduling, so when we write JAX programs, we are setting up data dependencies that XLA will respect, but XLA will not respect the specific order of operations written in JAX.
2. **Lifetimes** - XLA assumes that once a value is out of scope in the dependency graph, its memory can be freed for use by other values. If we have an op that asynchronously copies `x` -> `y`, we need to ensure that `x` is alive until the copy is complete, otherwise we will be copying from garbage memory.
3. **Defensive copies** - XLA reserves the right to create copies of values. We need to make sure we don't introduce unnecessary copies to a) avoid unnecessary runtime overhead and b) ensure correctness.
We will go over these issues one by one and suggest fixes.
### Scheduling
How do we explicitly force ops to happen in a particular order in JAX? Note that this is not a Pallas-specific problem; if we had implemented async ops using an alternative method, we'd still run into this.
One way is to introduce an optimization barrier into the XLA program. The optimization barrier will prevent XLA from moving ops around it.
Here’s our original code:
```python
def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z
```

XLA could choose to execute `x + 1` in any of three places:

```python
def f(x):
  z = x + 1
  fut = ppermute_start(x)
  y = ppermute_done(fut)
  return y, z

# OR

def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z

# OR

def f(x):
  fut = ppermute_start(x)
  y = ppermute_done(fut)
  z = x + 1
  return y, z
```
To force the `x + 1` to happen between the `ppermute` ops, we can use `optimization_barrier`, which is semantically the identity function (i.e. `lambda x: x`) but introduces an explicit data dependency between values. Specifically, if we make the `x` that is used in `x + 1` dependent on the `fut` returned by `ppermute_start`, it must happen after `ppermute_start`.

We also introduce a dependency that forces the output value `y` to depend on `z`.
```python
def f(x):
  fut = ppermute_start(x)
  x, fut = optimization_barrier((x, fut))  # x now depends on fut
  z = x + 1
  z, fut = optimization_barrier((z, fut))  # fut now depends on z
  y = ppermute_done(fut)
  return y, z
```
`optimization_barrier` is a good enough hammer for us to explicitly write out schedules.
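For reference, this barrier is exposed in JAX as `jax.lax.optimization_barrier`, so the schedule-pinned version of `f` can be spelled concretely (reusing the `ppermute_start`/`ppermute_done` we defined above):

```python
import jax

def f(x):
  fut = ppermute_start(x)
  # Identity on its operand, but pins x (and hence x + 1) after ppermute_start.
  x, fut = jax.lax.optimization_barrier((x, fut))
  z = x + 1
  # Pins ppermute_done (via fut) after z has been computed.
  z, fut = jax.lax.optimization_barrier((z, fut))
  y = ppermute_done(fut)
  return y, z
```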
### Lifetimes
Let’s look at our original code again and assume the ops are happening in the correct order.
```python
def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z
```
Let's look at the point in the program where XLA believes it is okay to free the buffer for `x`. It would be the point after which `x` is no longer used, specifically after `z = x + 1`.
```python
def f(x):
  fut = ppermute_start(x)
  z = x + 1
  # XLA can free x here!
  y = ppermute_done(fut)
  return y, z
```
If XLA frees `x` after `z = x + 1` has completed, we run into a very bad problem. The `ppermute` could still be actively copying `x` to the neighbor after `z = x + 1`, which means if `x` is freed, the `ppermute` will be reading from garbage memory!

How do we extend `x`'s lifetime to the `ppermute_done`? Well, we can introduce a data dependency! We need to modify our kernels a little bit to make this happen.

First, we rewrite `ppermute_start` to return `x`, aliasing it through the kernel.
```python
def ppermute_start_kernel(
    in_ref,
    send_sem,
    recv_sem,
    out_ref,
    _,
    *,
    axis_name,
):
  axis_size = jax.lax.psum(1, axis_name)
  left_neighbor = jax.lax.rem(
      jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
  )
  right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
  barrier_sem = pltpu.get_barrier_semaphore()
  pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
  pltpu.semaphore_wait(barrier_sem, 1)
  pltpu.make_async_remote_copy(
      in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
  ).start()


def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array, Array]:
  send_sem, recv_sem, x, out = pl.pallas_call(
      functools.partial(ppermute_start_kernel, axis_name=axis_name),
      out_shape=(
          pltpu.SemaphoreType.DMA(()),
          pltpu.SemaphoreType.DMA(()),
          jax.ShapeDtypeStruct(x.shape, dtype=x.dtype),
          jax.ShapeDtypeStruct(x.shape, dtype=x.dtype),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
      ],
      out_specs=(
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.ANY),
      ),
      input_output_aliases={0: 2},
  )(x)
  return send_sem, recv_sem, x, out
```
We then have `ppermute_done` take in `x` and do nothing with it.
```python
def ppermute_done_kernel(x_ref, ref, send_sem, recv_sem, _):
  # x_ref is unused; it is only an input so that x's lifetime extends to
  # this kernel.
  del x_ref
  pltpu.make_async_copy(ref, ref, send_sem).wait()
  pltpu.make_async_copy(ref, ref, recv_sem).wait()


def ppermute_done(send_sem, recv_sem, x, out) -> Array:
  out = pl.pallas_call(
      ppermute_done_kernel,
      out_shape=(
          jax.ShapeDtypeStruct(out.shape, dtype=out.dtype),
      ),
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.ANY),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
      ],
      out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
      input_output_aliases={1: 0},
  )(x, out, send_sem, recv_sem)
  return out
```
Now when we write:

```python
def f(x):
  *sems, x, out = ppermute_start(x)
  z = x + 1
  y = ppermute_done(*sems, x, out)
  return y, z
```
XLA can no longer free `x` because it is an input to `ppermute_done`! This means that `x`'s lifetime is tied to the `ppermute` and this code is now correct.
### Defensive copies
XLA, in its buffer assignment pass, analyzes which buffers are aliased to each other and inserts copies whenever an operation that aliases one of its inputs is not the final consumer of that input.
#### Background
Here's a simple example. Let's say we have an op `add_one_inplace` which takes in an array and adds one, but promises to do it in-place.
The following code would be legal.
```python
def f():
  x = jnp.arange(...)
  y = add_one_inplace(x)
  return y
```
However, if `x` had a separate consumer as well, the program may not execute correctly.
```python
def f():
  x = jnp.arange(...)
  y = add_one_inplace(x)
  return y, x * 2  # another x consumer!
```
This is because `x * 2` operates on the original `x`, but `add_one_inplace` clobbers the value in `x`. `x * 2` needs to read the original values of `x`, not the ones after we've added 1. XLA notices this and inserts a `copy` op (which is semantically the identity, but the input and output buffers will be different).
```python
def f(x):
  x2 = copy(x)
  y = add_one_inplace(x2)
  return y, x * 2
```
This pass in XLA ensures correctness in the presence of ops that perform in-place updates by forcing them to effectively be out-of-place with `copy` ops.
#### Copies with downstream ops
Let's revisit our example where we add 1 while `ppermute`-ing.
```python
def f(x):
  fut = ppermute_start(x)
  z = x + 1
  y = ppermute_done(fut)
  return y, z
```
If we unpack the future into its components, we'll see the aliasing patterns:
```python
def f(x):
  *sems, x2, y = ppermute_start(x)
  z = x + 1
  y = ppermute_done(*sems, x2, y)
  return y, z
```
We know that `x` is left unchanged by `ppermute_start` (that is, `x` is identical to `x2`), but XLA does not. In fact, to XLA this looks just like our `add_one_inplace` example, so it conservatively assumes that `ppermute_start` mutated `x` and that `x2` is the new aliased result. Therefore, when we do `z = x + 1`, we run into a consumer of the original buffer. XLA therefore introduces a copy!
```python
def f(x):
  x2 = copy(x)
  *sems, x2, y = ppermute_start(x2)
  z = x + 1
  y = ppermute_done(*sems, x2, y)
  return y, z
```
This copy is unnecessary because we know that `x2` is unchanged from `x`. In order to remove this copy, we'd need some mechanism to inform XLA that we are just forwarding a value. However, in the absence of that, we can rewrite our program a bit to explicitly use `x2` instead of `x`.
```python
def f(x):
  *sems, x2, y = ppermute_start(x)
  z = x2 + 1
  y = ppermute_done(*sems, x2, y)
  return y, z
```
Now, XLA doesn't see a separate consumer of `x`, so no more copy is introduced. However, this comes with a major downside: it forces us to unpack the future coming from `ppermute_start`. It couples the lifetime problem to the copying problem.
#### Loop aliasing
Let's consider a slightly more advanced example. Let's implement a function that uses a `fori_loop` with `ppermute` to send values around a ring.
```python
def f(x):
  def body(i, x):
    fut = ppermute_start(x)
    y = ppermute_done(fut)
    return y
  return fori_loop(0, 8, body, x)
```
One implementation detail of `fori_loop` is that the input and output buffers are automatically aliased to each other. Note that we are setting up some additional aliasing in the `ppermute_start` and `ppermute_done` ops. Let's run our own "buffer assignment" by coloring each of the values in the program to determine how many unique buffers we need.
First, we'll unpack the `fut` tuple that has the aliased `x` and `out` buffers.
```python
def f(x):
  def body(i, x):
    *sems, x, y = ppermute_start(x)
    y = ppermute_done(*sems, x, y)
    return y
  return fori_loop(0, 8, body, x)
```
Let's now color each of the values according to the unique buffer they are assigned. We have the input/output aliasing coming from `fori_loop`, the `x` aliasing coming from `ppermute_start`, and the `y` aliasing coming from `ppermute_done`.
```python
def f(x):
  def body(i, x):
    # ppermute_start aliases its input x to the x it returns.
    *sems, x, y = ppermute_start(x)
    # ppermute_done aliases the out buffer y to the y it returns.
    y = ppermute_done(*sems, x, y)
    # fori_loop aliases the carry input x to the carry output y.
    return y
  return fori_loop(0, 8, body, x)
```
If you run the alias analysis, you'll find that all of the buffers have been colored the same! Intuitively, this is problematic because if we are doing a loop of `ppermute`s, we can't receive into the same buffer we are sending from. We generally need an extra (i.e. a "double") buffer to receive into, and then usually we will switch the send/recv buffers on the next iteration. What XLA will do in practice is observe the buffer reuse and defensively insert a copy.
```python
def f(x):
  def body(i, x):
    x = copy(x)
    *sems, x, y = ppermute_start(x)
    y = ppermute_done(*sems, x, y)
    return y
  return fori_loop(0, 8, body, x)
```
This copy means `x` and `y` are no longer aliased to each other, and the program will be correct. However, do we need this copy? How do we introduce a double buffer to avoid an expensive copy on each iteration? The answer is unrolling!
We’ll manually unroll our code.
```python
def f(x):
  def body(i, x):
    *sems, x, x2 = ppermute_start(x)
    x2 = ppermute_done(*sems, x, x2)
    *sems, x2, y = ppermute_start(x2)
    y = ppermute_done(*sems, x2, y)
    return y
  return fori_loop(0, 4, body, x)
```
Now if we were to run the same alias analysis, we'll find that the buffers no longer all alias to each other and that we won't need to insert defensive copies to be correct.
Therefore, the simple solution to removing these copies is to use `fori_loop` with `unroll >= 2`.
```python
def f(x):
  def body(i, x):
    fut = ppermute_start(x)
    y = ppermute_done(fut)
    return y
  return fori_loop(0, 8, body, x, unroll=2)
```
That’s sufficient to implement this loop without extra copies!
#### Passing futures across loop boundaries
Let's now look at an even more advanced example. We'll implement the same program as before but stagger the loop, where we begin the `ppermute` in a prologue before the loop, and wait on the `ppermute` at the beginning of the loop.
```python
def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    x = ppermute_done(fut)
    fut = ppermute_start(x)
    return fut
  fut = fori_loop(0, 7, body, fut)
  return ppermute_done(fut)
```
In this example, rather than passing a value `x` from one iteration of the loop to the next, we are passing a future value.
Let’s unpack the future again to see what’s happening.
```python
def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    *sems, x, out = fut
    x = ppermute_done(*sems, x, out)
    (*sems, x, out) = ppermute_start(x)
    return (*sems, x, out)
  (*sems, x, out) = fori_loop(0, 7, body, fut)
  return ppermute_done(*sems, x, out)
```
So we're explicitly threading the semaphores, the input buffer, and the target output buffer as a loop carry. What happens if we run alias analysis now? Well, we'll run into the same aliasing issue as in the previous section, where `x` and `out` will be aliased to each other. XLA will introduce a copy.
```python
def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    *sems, x, out = fut
    out = copy(out)
    x = ppermute_done(*sems, x, out)
    (*sems, x, out) = ppermute_start(x)
    return (*sems, x, out)
  (*sems, x, out) = fori_loop(0, 7, body, fut)
  return ppermute_done(*sems, x, out)
```
In this case, we inserted a copy on `out`. However, this is a really bad scenario because `out` is being actively copied into! Even if we insert a copy on `x`, we will also run into issues because then `x`'s lifetime will not extend to the `ppermute_done`. This is very, very bad! We will not only get copies, but we will also get incorrect results!
The solution, as we observed before, is to avoid the copies by avoiding aliasing all the buffers via unrolling. So, if we do:
```python
def f(x):
  fut = ppermute_start(x)
  def body(i, fut):
    x = ppermute_done(fut)
    fut = ppermute_start(x)
    return fut
  fut = fori_loop(0, 7, body, fut, unroll=2)
  return ppermute_done(fut)
```
our program should now be correct.
### Putting it all together
So we’ve come up with some rules of thumb:
1. If we have operations dependent on the input value to the `ppermute`, unpack the future to use the aliased value instead of the original value.
2. Use `unroll >= 2` when doing `ppermute`s in a loop body.
Let's combine everything into one function that does `ppermute`s in a loop and accumulates the result.
```python
def f(x):
  out = jnp.zeros_like(x)
  fut = (*sems, x, out) = ppermute_start(x)
  out = out + x

  def body(i, carry):
    out, fut = carry
    x = ppermute_done(fut)
    fut = (*sems, x, out) = ppermute_start(x)
    out = out + x
    return out, fut

  out, fut = fori_loop(0, 7, body, (out, fut), unroll=2)
  return out, ppermute_done(fut)
```
Note that in this example, we don't need `optimization_barrier`s because the loop boundary acts as a scheduling barrier, splitting up the `start`s and `done`s.
That’s it, we are done! This will be the official API for doing async ops in Pallas. Thank you everyone! Mission accomplished!
OR IS IT?
## Revenge of the State
While it seems we have worked around the copy and correctness issues by using some clever tricks, we are still in an awkward position. This API is powerful, but it has many, many footguns and caveats. There are likely many more edge cases we will need to deal with, some of which require deep knowledge of XLA to predict or understand. Should we release an API like this? Or is there an alternative?
Well, the answer may have been in front of us this whole time.
Let's run through this whole exercise one more time, *except*, let's write the stateful version. This means each of our custom async ops now operates on `Ref`s instead of values.
```python
def ppermute_start_stateful(x_ref, y_ref) -> tuple[Semaphore, Semaphore]:
  ...

def ppermute_done_stateful(send_sem, recv_sem, x_ref, y_ref) -> None:
  ...
```
Let’s assume we can implement these in Pallas and see what our new programs will look like. Let’s start with a basic collective permute:
```python
def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))
  fut = ppermute_start_stateful(x_ref, y_ref)
  ppermute_done_stateful(*fut, x_ref, y_ref)
  return y_ref[...]
```
It's a little bit more verbose than our original value-based version, but it has a few key differences. The first is that we create an "empty" `Ref` to receive the result of the `ppermute`, unlike the value-based version, which creates a value for us. One neat thing is that the lifetime of `x_ref` is clear here: it lives until `ppermute_done_stateful`. We don't need to "sneak" the `x` value into the op like we did before.
Another difference becomes clearer when we try adding an op between the `start`/`done`.
```python
def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))
  fut = ppermute_start_stateful(x_ref, y_ref)
  x_ref[...] += 1
  ppermute_done_stateful(*fut, x_ref, y_ref)
  return y_ref[...]
```
Before, we ran into scheduling ambiguity, where XLA could re-order the add w.r.t. the `ppermute`. With stateful semantics, we actually add in an ordering constraint! `x_ref[...] += 1` mutates `x_ref`, so it can't be moved w.r.t. `ppermute_done_stateful`. JAX can inject these scheduling constraints as part of the lowering to HLO.
The final key difference is evident when we try our loop examples.
```python
def f(x):
  x_ref = make_ref(x)
  y_ref = make_ref(zeros_like(x))

  def body(i, _):
    fut = ppermute_start_stateful(x_ref, y_ref)
    ppermute_done_stateful(*fut, x_ref, y_ref)
    # Now switch to y_ref -> x_ref
    fut = ppermute_start_stateful(y_ref, x_ref)
    ppermute_done_stateful(*fut, y_ref, x_ref)

  fori_loop(0, 8 // 2, body, None)
  return x_ref[...]
```
Because of the requirement that we have a separate buffer ready to receive the `ppermute`, we were forced to write our code in such a way that unrolls it! There is no way to write the version that requires copying, because that would involve a `ppermute` that sends from a `Ref` into itself, which doesn't really make sense.
To handle this without the manual unrolling, we'd create a scratch buffer with a leading `2` dimension that acts as the send/recv target across iterations, switching on each one. This is the same pattern we use internally in Pallas kernels when writing manually overlapped kernels.
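As an illustration, here is a minimal sketch of that pattern, reusing the hypothetical `make_ref` and stateful ops from above and assuming `Ref`s can be sliced along their leading axis with an `.at`-style indexer:

```python
def f(x):
  # Scratch buffer with a leading 2 dimension: one slot sends while the
  # other receives, and the roles swap every iteration.
  buf_ref = make_ref(jnp.stack([x, jnp.zeros_like(x)]))

  def body(i, _):
    send, recv = i % 2, (i + 1) % 2
    fut = ppermute_start_stateful(buf_ref.at[send], buf_ref.at[recv])
    ppermute_done_stateful(*fut, buf_ref.at[send], buf_ref.at[recv])
    return None

  fori_loop(0, 8, body, None)
  return buf_ref[0]  # after an even number of steps, the data is in slot 0
```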
The realization here is that being stateful forces us to deal with a lot of the issues that pop up with value semantics earlier on. We define them away!
1. **Scheduling** - stateful ops that have `Ref`s as inputs force an ordering of our program. Note that this will schedule operations on the same `Ref` w.r.t. each other. We might also need an `opt_barrier_stateful` to enforce more ordering constraints.
2. **Lifetimes** - `Ref` lifetimes can be scoped via `run_state` (see the sketch after this list) or could be inputs to stateful ops.
3. **Defensive copies** - Using `Ref`s forces us to handle buffer assignment "manually" and the lowering can ensure the aliasing works out to avoid any copies.
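To make the lifetimes point concrete, here is a sketch of the basic stateful permute under a `run_state`-style scoping API (hypothetical here) that turns a function over `Ref`s into a pure function over values:

```python
def f(x):
  def stateful(refs):
    x_ref, y_ref = refs
    fut = ppermute_start_stateful(x_ref, y_ref)
    ppermute_done_stateful(*fut, x_ref, y_ref)

  # Both Refs live exactly as long as the run_state scope, so x_ref's
  # lifetime trivially covers the in-flight copy.
  _, y = run_state(stateful)((x, jnp.zeros_like(x)))
  return y
```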
Another important fundamental limitation is that we eventually stage out an HLO program where the live buffers and semaphores are represented as array value types. XLA does not provide guarantees about the buffer lifetimes of these intermediate values or which memory spaces they live in. Therefore, it is possible for XLA to copy array values even while they are actively being copied into by Pallas kernels. This is easy to verify in HLO, but it is a sharp edge of using custom calls to represent asynchronous operations in HLO.
## Conclusion
We've gone over some tricky challenges when it comes to async ops in Pallas and JAX. `Ref`s seem like a promising way of representing these ops that circumvents some of the issues that come up with value semantics. However, a downside is that it puts stateful JAX front and center, which we haven't yet done outside of Pallas. It's worth thinking about whether we should educate users about stateful ops, or provide a more dangerous API. We also don't know whether everything we want to do is expressible via `Ref`s. We should also brainstorm alternatives to state to flesh out the design space. For example, what if XLA offered a first-class futures API that respected lifetimes, and could automatically do things like double-buffer loops with futures in them? That might be a viable alternative, but the tradeoff would be giving more control to the compiler vs. explicit control from the user.
