Sequencing side-effects in JAX
sharadmv@ May 9 2022
Overview
When we write JAX code, we can usually pretend we’re writing single-threaded, eagerly-executed Python even though underneath the hood, JAX and its runtime may execute it asynchronously in the background. As long as we write pure (side-effect-free) code, these performance optimizations are usually invisible to us and don’t interfere with our single-threaded mental model. Asynchronous execution is great – we get performant, parallel code without having to think about it at all!
However, in the presence of side-effects, the illusion begins to break down and the cracks in our mental model start to show. Specifically, these differences show up when we think about the order in which side-effects happen.
In this design note, we explore the interaction between JAX’s execution model and the ordering of side-effects. We also provide a way of enforcing a “single-threaded” ordering of effects.
Background
When we write the following Python code
    def f():
        print("hello")
        return 2

    def g():
        print("world")
        return 3

    f()
    g()
we expect "hello" to be printed before "world". This might seem obvious but consider the following JAX code:
    @partial(jax.jit, device=<device0>)
    def f():
        return 2

    @partial(jax.jit, device=<device1>)
    def g():
        return 3

    f()
    g()
In many cases, JAX will execute f and g in parallel, dispatching the computations onto different threads – g might actually be executed before f. Parallel execution is a nice performance optimization, especially if copying to and from a device is expensive (see the asynchronous dispatch note for more details). In practice, however, we often don’t need to think about asynchronous dispatch because we’re writing pure functions and only care about the inputs and outputs of functions – we’ll naturally block on future values.
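As a rough analogy in plain Python (a toy model, not JAX's actual runtime), asynchronous dispatch can be pictured with futures: each call returns immediately, and we only block when we read a value. The `device0` and `device1` executors below are hypothetical stand-ins for devices.

```python
from concurrent.futures import ThreadPoolExecutor

# Hypothetical stand-ins for two devices: each runs one computation at a time.
device0 = ThreadPoolExecutor(max_workers=1)
device1 = ThreadPoolExecutor(max_workers=1)


def f():
    return 2


def g():
    return 3


# "Dispatch" returns a future immediately; f and g may run in either order.
future_f = device0.submit(f)
future_g = device1.submit(g)

# Reading the values blocks until they are ready, so the execution order
# is unobservable as long as f and g are pure.
results = (future_f.result(), future_g.result())
print(results)

device0.shutdown()
device1.shutdown()
```

Because both functions are pure, the only thing we can observe is the pair of results, which is the same whichever function happened to run first.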
However, now imagine that we have a jax.print function that works inside of JIT-ted JAX functions (host_callback.id_print is an example of this). Let’s return to the previous example except with prints in the mix.
    @partial(jax.jit, device=<device0>)
    def f():
        jax.print("hello")
        return 2

    @partial(jax.jit, device=<device1>)
    def g():
        jax.print("world")
        return 3

    f()
    g()
Thanks to asynchronous dispatch, we could actually see "world" being printed before "hello". The reordering of the print side-effects breaks the illusion of a single-threaded execution model.
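The reordering can be reproduced in the same toy model. This sketch is plain Python rather than JAX; the `log` list stands in for the terminal, and a `threading.Event` is used (as an assumption for the sake of a deterministic example) to force the interleaving in which the second computation's effect lands first.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

log = []                       # stands in for the terminal
g_printed = threading.Event()  # used only to force one particular interleaving


def f():
    # Simulate device 0 being slow: f's print waits until g has printed.
    g_printed.wait()
    log.append("hello")
    return 2


def g():
    log.append("world")
    g_printed.set()
    return 3


with ThreadPoolExecutor(max_workers=1) as device0, \
     ThreadPoolExecutor(max_workers=1) as device1:
    future_f = device0.submit(f)  # dispatched first
    future_g = device1.submit(g)  # dispatched second
    values = (future_f.result(), future_g.result())

print(values, log)
```

The returned values are unaffected, but the log records "world" before "hello" even though `f` was dispatched first.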
Another example of where side-effects can “reveal” out-of-order execution is when we compile JAX programs. Consider the following JAX code:
    @jax.jit
    def f(x):
        jax.print("hello")
        jax.print("world")
        return x
Even though in Python, we’ve written the "hello" print before the "world" print, a compiler like XLA is free to reorder them because there’s no explicit data-dependence between the prints.
Motivation
We’d like to support “ordered” effects. When we say ordered, we mean that the effects occur in the same order as they would if we were executing a single-threaded Python program. This is our main desideratum. In the presence of explicit parallelism like pmap or user threads, we don’t need to maintain this behavior but at least if the user is not explicitly requesting parallelism, we’d like to preserve a single-threaded ordering.
Before we dive in more, let’s first step back and ask ourselves if it is okay if we reorder effects in the name of performance, and conversely, do we need to enforce an ordering on effects at all? In some cases, we don’t need ordering. Maybe some side-effects shouldn’t adversely affect the performance of a JAX program. However, for other side-effects, we may want to enforce a single-threaded program order so users don’t get counterintuitive behavior. Consider a logging effect.
    @jax.jit
    def f(x, y):
        log_value(x)
        log_value(y)

    f(1, 2)
If log_value is mutating a global list, we might expect that we add x before adding y. For a stricter effect, we may want the option to order the effects.
Enforcing ordered effects
The main tool we have to enforce the ordering of computations is data-dependence. Simply put, if a function g has an input that is the output of a function f, f must be executed before g.
However, we may have side effects like prints that have no inputs at all, so naively we couldn’t sequence them. We thus use tokens as a means of injecting artificial data-dependence into a computation.
What is a token? A token is just a dummy value that can be threaded in and out of a computation. By threading the same token in and out of several computations, we enforce that they have to happen in a certain order. Let’s take the previous print example and see what it would look like with tokens in the mix:
    @jax.jit
    def f(token, x):
        token = jax.print(token, "hello")
        token = jax.print(token, "world")
        return token, x
If we rewrite jax.print to take in and return a token, we have now sequenced the two prints since the input to the second print depends on the output of the first print. The actual value of token can be anything really, but we’ll see in practice that the tokens are invisible to users.
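To see why an artificial data dependence is enough, here is a small sketch of a scheduler that, like a compiler, only respects data dependencies. It is a toy in plain Python, not how XLA schedules operations; the `schedule` function and the op names are hypothetical. Among the ops that are ready, it deliberately picks the one written last, to mimic a compiler that is free to reorder.

```python
def schedule(ops):
    """Return one valid execution order for `ops`.

    `ops` maps an op name to the list of op names whose outputs it consumes.
    Among the ops that are ready, the one written last is picked on purpose.
    """
    done, order = set(), []
    while len(order) < len(ops):
        ready = [name for name, deps in ops.items()
                 if name not in done and all(d in done for d in deps)]
        chosen = ready[-1]
        done.add(chosen)
        order.append(chosen)
    return order


# No data dependence between the prints: they may be reordered.
unordered = {"print_hello": [], "print_world": []}

# Token threading: print_world consumes the token produced by print_hello.
ordered = {
    "token_in": [],
    "print_hello": ["token_in"],
    "print_world": ["print_hello"],
}

print(schedule(unordered))  # ['print_world', 'print_hello']
print(schedule(ordered))    # ['token_in', 'print_hello', 'print_world']
```

With the token edges in place, even this adversarial scheduler has no choice but to run the prints in program order.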
Runtime tokens vs. compiler tokens
Here we will actually start talking about implementation details. In practice, we’ll need two separate types of tokens to sequence effects: one for each of the aforementioned sources of reordering. We’ll need runtime tokens to sequence asynchronously dispatched side-effecting computations and we’ll need compiler tokens to sequence effects within computations.
In practice, our computation will be rewritten to look like this:
    @jax.jit
    def f(runtime_token, x):
        compiler_token = new_compiler_token()
        compiler_token = jax.print(compiler_token, "hello")
        compiler_token = jax.print(compiler_token, "world")
        return runtime_token, x
Notice how the runtime tokens are only used at the JIT boundary and the compiler tokens are only within the compiled code. Compiler tokens are created during “lowering” (when we convert Python code to a lower level representation like HLO or StableHLO) but runtime tokens need to be managed in Python since they’re being threaded in and out of JIT-ted functions.
Furthermore, notice that the runtime tokens are “disconnected” from the compiler tokens, meaning there’s no data dependency between them. This could potentially be dangerous, as we would lose the data dependence between the bodies of two dispatched function calls. However, if we assume “strict execution” – i.e. a dispatched function will only start execution when all of its inputs are ready and all of its outputs will become ready at the same time – we are safe to create a fresh compiler token and return a non-output-dependent runtime token.
Managing runtime tokens
To manage runtime tokens on behalf of the user, we’ll need to hook into JAX’s dispatch machinery. Whenever we call a JIT-ted function, we eventually bottom out in a function that looks like this:
    def _execute(compiled_computation, *args):
        outputs = compiled_computation.execute(*args)
        return outputs
At this point we need to “inject” the runtime tokens into the computation and “extract” them from the computation’s outputs:
    def _execute(compiled_computation, *args):
        runtime_token = get_runtime_token()  # Grab global token
        runtime_token, *outputs = compiled_computation.execute(runtime_token, *args)
        update_runtime_token(runtime_token)  # Update global token
        return outputs
What is runtime_token exactly? Well, we need to be able to pass it into a compiled_computation, which means it needs to be some sort of array (for now, since there’s no shared token representation inside and outside compiled JAX code). In practice we can use a (0,)-shaped array to minimize overheads.
We also need to think about the multiple device use case, e.g. the first example where we first call a JIT-ted function on device 0 and then one on device 1. In that case, we need to also copy the runtime token returned from the first computation (which lives on device 0) to device 1 so we can pass it into the second computation. If two subsequent computations share the same device, this copy is not necessary.
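The dispatch hook described above can be sketched in the same future-based toy model. This is an illustration under simplifying assumptions, not JAX's dispatch code: the token is a Python `Future` rather than a `(0,)`-shaped array, the cross-device copy is not modeled, and `execute`, `devices`, and `log` are hypothetical names.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor

log = []
devices = {0: ThreadPoolExecutor(max_workers=1),
           1: ThreadPoolExecutor(max_workers=1)}

# Global runtime token: ready once the most recently dispatched effectful
# computation has finished. It starts out ready.
_runtime_token = Future()
_runtime_token.set_result(None)


def execute(device, computation, *args):
    """Toy dispatch: thread the global runtime token through `computation`."""
    global _runtime_token
    input_token = _runtime_token  # grab global token

    def run():
        input_token.result()      # strict execution: wait for all inputs
        return computation(*args)

    output = devices[device].submit(run)
    _runtime_token = output       # update global token
    return output


def f():
    time.sleep(0.05)              # device 0 is slow
    log.append("hello")
    return 2


def g():
    log.append("world")
    return 3


out_f = execute(0, f)
out_g = execute(1, g)
values = (out_f.result(), out_g.result())
print(values, log)

for executor in devices.values():
    executor.shutdown()
```

Even though `f` is slower and runs on a different toy device, `g` cannot start until the token produced by `f` is ready, so the log reads "hello" then "world".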
Adding compiler tokens
When we lower Python code to HLO or StableHLO we need to create a token at the start of the computation and ensure it is available when we have side-effecting computations that need to be ordered. The side-effecting computations will take the token as input and return it as an output.
The implementation of this token threading involves upgrading the JAX lowering machinery to do this bookkeeping automatically. The main challenges involve dealing with higher-order primitives like call primitives and control-flow primitives. We won’t go into details on how to handle those in this design note.
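For a flat program with no higher-order primitives, the bookkeeping can be sketched as a simple rewriting pass. The tuple-based representation and the `thread_tokens` function below are hypothetical and only illustrate the idea; they are not JAX's lowering machinery or real HLO.

```python
def thread_tokens(ops):
    """Rewrite a flat list of (name, inputs, is_effectful) ops.

    A fresh token is created at the start, and every effectful op is made to
    consume the current token and produce the next one. Each lowered op is a
    (name, inputs, output_token) tuple; pure ops have no output token.
    """
    lowered = [("create_token", [], "tok0")]
    current, count = "tok0", 0
    for name, inputs, is_effectful in ops:
        if is_effectful:
            count += 1
            new_token = f"tok{count}"
            lowered.append((name, inputs + [current], new_token))
            current = new_token
        else:
            lowered.append((name, inputs, None))
    return lowered


program = [
    ("print_hello", [], True),
    ("add", ["x", "y"], False),
    ("print_world", [], True),
]
lowered = thread_tokens(program)
for op in lowered:
    print(op)
```

After the pass, `print_world` consumes the token produced by `print_hello`, while the pure `add` is left free to be scheduled anywhere.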
Blocking on output tokens
Adding support for runtime and compiler tokens for side-effecting computations is important for sequencing but there’s also another subtle use-case for tokens, which is blocking on side-effecting computations. Even if we don’t want a side-effecting computation to be ordered we may still want to wait on its completion. Currently we have jax.block_until_ready, which waits until a future value has its result ready. However, with side-effecting computations, we may have functions that don’t have a return value but are still executing a side-effect. Take the simple example here:
    @jax.jit
    def f():
        jax.print("hello world")
        return

    f()  # Executed asynchronously
This compiled computation takes no explicit inputs and has no explicit outputs. If it was an ordered print effect, we could block on the returned runtime token. However, when this is an unordered computation we don’t do any token threading. How do we wait for f() to finish executing when we have no output value to call block_until_ready on? Well, we could apply our same token strategy except we only return runtime tokens and don’t take them as inputs. This will give us a value to block on that will only be ready once f() is done being executed. We’ll call these tokens output tokens. We end up with a function that looks like this:
    @jax.jit
    def f():
        jax.print("hello world")
        return new_runtime_token()

    f()  # Executed asynchronously
Underneath the hood, we’ll manage the output tokens in the same way we manage the runtime tokens but provide a method for users to block on the current set of output tokens. Unlike runtime tokens, output tokens need to be device-specific. Consider a single device use-case:
    @jax.jit
    def f():
        jax.print("hello")

    @jax.jit
    def g():
        jax.print("world")

    f()
    g()
Since f() and g() are executed on the same device, blocking on g()’s output token effectively blocks on f() since (as of now!), the JAX runtime does not interleave computations executed on the same device. We’ll have to revise this entire design if that changes, of course.
However, consider the two device use-case:
    @partial(jax.jit, device=<device0>)
    def f():
        jax.print("hello")

    @partial(jax.jit, device=<device1>)
    def g():
        jax.print("world")

    f()
    g()
Here we don’t want to explicitly sequence f() and g() but want to wait for both of them to finish. We’ll need one output token for f() and one for g() and we’ll block on both of those tokens:
    @partial(jax.jit, device=<device0>)
    def f():
        jax.print("hello")
        return new_runtime_token()

    @partial(jax.jit, device=<device1>)
    def g():
        jax.print("world")
        return new_runtime_token()

    t0 = f()
    t1 = g()
    block_until_ready((t0, t1))
We’ll thus need a per-device output token so we can avoid sequencing computations on different devices while offering the ability to block on side-effecting computations. We end up with the following (approximate) change to the JAX dispatch machinery:
    def _execute(compiled_computation, *args):
        output_token, *outputs = compiled_computation.execute(runtime_token, *args)
        update_output_token(output_token, compiled_computation.device)
        return outputs
We’ll also need to expose a function that blocks on the output token:
    def effects_barrier():
        output_token.block_until_ready()
Note that blocking on output tokens may not be very common since most JAX computations will return a value to block on. However, output tokens are helpful for testing and profiling, and are good to support so that we have a consistent and cohesive effect system.
Some more details
All of the aforementioned token management infrastructure will be thread-local. This means that each user thread will have its own independent stream of runtime tokens. Sequencing is only promised at a user thread level.
In practice, we have one runtime token per effect. Different instances of that effect will be sequenced. This is to avoid sequencing effectful computations that may not have any relation to each other. Technically this goes against our original goal of enforcing a single-threaded Python program ordering, but this is a tradeoff that could be modulated by having both “effect”-specific tokens and “global” tokens.
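A minimal sketch of thread-local token state, using Python's `threading.local`. The integer counter stands in for a real token, and `effectful_call`, `worker`, and `seen` are hypothetical; `get_runtime_token` and `update_runtime_token` reuse the names from the dispatch pseudocode above, but their bodies here are assumptions.

```python
import threading

_state = threading.local()


def get_runtime_token():
    # Each thread lazily gets its own token stream, starting at 0.
    if not hasattr(_state, "token"):
        _state.token = 0
    return _state.token


def update_runtime_token(token):
    _state.token = token


def effectful_call():
    # Stand-in for dispatching an effectful computation on this thread.
    update_runtime_token(get_runtime_token() + 1)


seen = {}


def worker(name, n_calls):
    for _ in range(n_calls):
        effectful_call()
    seen[name] = get_runtime_token()


threads = [threading.Thread(target=worker, args=("a", 3)),
           threading.Thread(target=worker, args=("b", 5))]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(seen)
```

Each thread only ever sees its own token stream, so effects are sequenced within a thread and left unordered across threads.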
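The effect of keeping one token per effect can be illustrated by listing the dependency edges it creates. The `add_token_edges` function and the effect names "print" and "log" are hypothetical, chosen only for this sketch.

```python
def add_token_edges(calls):
    """Given (name, effect) calls in program order, return dependency edges.

    Each effect type has its own token, so a call depends only on the
    previous call with the same effect.
    """
    last_by_effect = {}
    edges = []
    for name, effect in calls:
        if effect in last_by_effect:
            edges.append((last_by_effect[effect], name))
        last_by_effect[effect] = name
    return edges


calls = [
    ("print_a", "print"),
    ("log_a", "log"),
    ("print_b", "print"),
    ("log_b", "log"),
]
edges = add_token_edges(calls)
print(edges)  # [('print_a', 'print_b'), ('log_a', 'log_b')]
```

The prints are ordered relative to each other, as are the log calls, but no edge connects a print to a log call. A single “global” token would instead chain all four calls in program order.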
