
Just-in-time compilation

In this section, we will further explore how JAX works, and how we can make it performant. We will discuss the jax.jit() transformation, which will perform Just In Time (JIT) compilation of a JAX Python function so it can be executed efficiently in XLA.

How JAX transformations work

In the previous section, we discussed that JAX allows us to transform Python functions. JAX accomplishes this by reducing each function into a sequence of primitive operations, each representing one fundamental unit of computation.

One way to see the sequence of primitives behind a function is using jax.make_jaxpr():

```python
import jax
import jax.numpy as jnp

global_list = []

def log2(x):
    global_list.append(x)
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))
```

```
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0:f32[]
    d:f32[] = div b c
  in (d,) }
```

The JAX internals: The jaxpr language section of the documentation provides more information on the meaning of the above output.

Importantly, notice that the jaxpr does not capture the side effect present in the function: there is nothing in it corresponding to global_list.append(x). This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code. If pure function and side effect are unfamiliar terms, they are explained in a little more detail in 🔪 JAX - The Sharp Bits 🔪: Pure Functions.

Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers. Moreover, JAX often can't detect when side effects are present. (If you want debug printing, use jax.debug.print(). To express general side effects at the cost of performance, see jax.experimental.io_callback(). To check for tracer leaks at the cost of performance, use the jax.check_tracer_leaks() context manager.)
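To make the jax.debug.print() suggestion concrete, here is a minimal sketch contrasting it with Python's print() inside a jitted function (the function name log2_debug is made up for illustration). The plain print() fires only while JAX traces the function and shows a tracer, whereas jax.debug.print() is staged into the compiled computation and reports the runtime value on every call.

```python
import jax
import jax.numpy as jnp

@jax.jit
def log2_debug(x):
    print("trace-time x:", x)             # runs only while tracing; shows a tracer
    jax.debug.print("run-time x: {}", x)  # staged into the computation; shows the value
    return jnp.log(x) / jnp.log(2.0)

result = log2_debug(8.0)   # traces, compiles, runs: both lines print
result = log2_debug(16.0)  # cached: only the jax.debug.print line appears
```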

When tracing, JAX wraps each argument in a tracer object. These tracers then record all JAX operations performed on them during the function call (which happens in regular Python). Then, JAX uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record Python side effects, those side effects do not appear in the jaxpr. However, the side effects still happen during the trace itself.
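As a rough check of this description, the sketch below repeats the log2 example, lists the primitive recorded by each equation of the resulting jaxpr, and confirms that the append still ran once while tracing. It assumes that jax.make_jaxpr returns a ClosedJaxpr whose .jaxpr.eqns attribute holds the recorded equations, each with a .primitive.name.

```python
import jax
import jax.numpy as jnp

global_list = []

def log2(x):
    global_list.append(x)  # Python side effect: happens at trace time only
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2

closed_jaxpr = jax.make_jaxpr(log2)(3.0)

# Each equation in the jaxpr records one primitive; none corresponds to the append.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)

# The side effect still happened, exactly once, while JAX traced log2.
print(len(global_list))  # 1
```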

Note: the Python print() function is not pure: the text output is a side effect of the function. Therefore, any print() calls will only happen during tracing, and will not appear in the jaxpr:

```python
def log2_with_print(x):
    print("printed x:", x)
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2

print(jax.make_jaxpr(log2_with_print)(3.))
```

```
printed x: JitTracer<~float32[]>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0:f32[]
    d:f32[] = div b c
  in (d,) }
```

See how the printed x is a tracer object rather than a number? That's the JAX internals at work.

The fact that the Python code runs at least once is strictly an implementation detail, and so shouldn’t be relied upon. However, it’s useful to understand as you can use it when debugging to print out intermediate values of a computation.

A key thing to understand is that a jaxpr captures the function as executed on the parameters given to it. For example, if we have a Python conditional, the jaxpr will only know about the branch we take:

```python
def log2_if_rank_2(x):
    if x.ndim == 2:
        ln_x = jnp.log(x)
        ln_2 = jnp.log(2.0)
        return ln_x / ln_2
    else:
        return x

print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))
```

```
{ lambda ; a:i32[3]. let  in (a,) }
```
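Tracing the same function with a rank-2 input takes the other branch, so the resulting jaxpr does contain the logarithm. A small sketch, using jnp.ones((2, 2)) as an arbitrary rank-2 input and counting equations instead of relying on the exact printed text:

```python
import jax
import jax.numpy as jnp

def log2_if_rank_2(x):
    if x.ndim == 2:
        ln_x = jnp.log(x)
        ln_2 = jnp.log(2.0)
        return ln_x / ln_2
    else:
        return x

jaxpr_rank_1 = jax.make_jaxpr(log2_if_rank_2)(jnp.array([1, 2, 3]))
jaxpr_rank_2 = jax.make_jaxpr(log2_if_rank_2)(jnp.ones((2, 2)))

# The rank-1 trace recorded no operations; the rank-2 trace recorded the log branch.
print(len(jaxpr_rank_1.jaxpr.eqns))
print(jaxpr_rank_2)
```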

JIT compiling a function

As explained before, JAX enables operations to execute on CPU/GPU/TPU using the same code. Let's look at an example of computing a Scaled Exponential Linear Unit (SELU), an operation commonly used in deep learning:

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
# %timeit is an IPython/Jupyter magic command.
%timeit selu(x).block_until_ready()
```

```
3.63 ms ± 79.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

The code above is sending one operation at a time to the accelerator. This limits the ability of the XLA compiler to optimize our functions.

Naturally, what we want to do is give the XLA compiler as much code as possible, so it can fully optimize it. For this purpose, JAX provides the jax.jit() transformation, which will JIT compile a JAX-compatible function. The example below shows how to use JIT to speed up the previous function.

```python
selu_jit = jax.jit(selu)

# Pre-compile the function before timing...
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()
```

```
281 μs ± 1.68 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```

Here’s what just happened:

  1. We defined selu_jit as the compiled version of selu.

  2. We called selu_jit once on x. This is where JAX does its tracing: it needs to have some inputs to wrap in tracers, after all. The jaxpr is then compiled using XLA into very efficient code optimized for your CPU, GPU, or TPU. Finally, the compiled code is executed to satisfy the call. Subsequent calls to selu_jit with inputs of the same shape and dtype will use the compiled code directly, skipping the Python implementation entirely. (If we didn't include the warm-up call separately, everything would still work, but then the compilation time would be included in the benchmark. It would still be faster, because we run many loops in the benchmark, but it wouldn't be a fair comparison.)

  3. We timed the execution speed of the compiled version. (Note the use of block_until_ready(), which is required due to JAX's Asynchronous dispatch.)
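The stages listed above can also be run one at a time through JAX's ahead-of-time API, jax.jit(...).lower(...).compile(). The sketch below repeats the selu definition so it stands alone and uses a small, arbitrary input; it is only meant to make the separate tracing/lowering and compilation steps visible, not to replace calling selu_jit directly.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.linspace(-3.0, 3.0, 7)

lowered = jax.jit(selu).lower(x)  # trace selu for x's shape/dtype and lower it for XLA
compiled = lowered.compile()      # run the XLA compiler on the lowered computation
out = compiled(x)                 # execute the compiled code; no Python tracing here
```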

Why can’t we just JIT everything?

After going through the example above, you might be wondering whether we should simply apply jax.jit() to every function. To understand why this is not the case, and when we should or shouldn't apply jit, let's first check some cases where JIT doesn't work.

```python
# Condition on value of x.
def f(x):
    if x > 0:
        return x
    else:
        return 2 * x

jax.jit(f)(10)  # Raises an error
```

```
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_1892/2956679937.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
```
```python
# While loop conditioned on x and n.
def g(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

jax.jit(g)(10, 20)  # Raises an error
```

```
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_1892/722961019.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
```

The problem in both cases is that we tried to condition the trace-time flow of the program using runtime values. Traced values within JIT, like x and n here, can only affect control flow via their static attributes, such as shape or dtype, and not via their values. For more detail on the interaction between Python control flow and JAX, see Control flow and logical operators with JIT.
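By contrast, branching on a static attribute works under jit, because attributes such as ndim and shape are known at trace time. A minimal sketch with a hypothetical helper, sum_rows_if_matrix; each distinct input rank gets its own trace:

```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_rows_if_matrix(x):
    # x.ndim is static under jit, so this Python branch is resolved at trace time.
    if x.ndim == 2:
        return jnp.sum(x, axis=1)
    return x

print(sum_rows_if_matrix(jnp.ones((2, 3))))  # [3. 3.]
print(sum_rows_if_matrix(jnp.ones(3)))       # [1. 1. 1.]
```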

One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special Control flow operators like jax.lax.cond(). However, sometimes that is not possible or practical. In that case, you can consider JIT-compiling only part of the function. For example, if the most computationally expensive part of the function is inside the loop, we can JIT-compile just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot):
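For the two failing examples f and g, the first two options might look like the following sketch: jnp.where or jax.lax.cond for f, and jax.lax.while_loop for g. The names f_where, f_cond and g_while are illustrative. Note that jnp.where computes both branches and selects elementwise, while lax.cond traces both branches but executes only the selected one.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f_where(x):
    # Both 2 * x and x are computed; jnp.where picks based on x > 0.
    return jnp.where(x > 0, x, 2 * x)

@jax.jit
def f_cond(x):
    # lax.cond(pred, true_fun, false_fun, operand) chooses a branch at run time.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: 2 * v, x)

@jax.jit
def g_while(x, n):
    # lax.while_loop(cond_fun, body_fun, init_val) replaces the Python loop,
    # so the loop bound n may now be a traced value.
    i = jax.lax.while_loop(lambda i: i < n, lambda i: i + 1, 0)
    return x + i

print(f_where(10), f_cond(-3), g_while(10, 20))  # 10 -6 30
```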

```python
# While loop conditioned on x and n with a jitted body.
@jax.jit
def loop_body(prev_i):
    return prev_i + 1

def g_inner_jitted(x, n):
    i = 0
    while i < n:
        i = loop_body(i)
    return x + i

g_inner_jitted(10, 20)
```

```
Array(30, dtype=int32, weak_type=True)
```

Marking arguments as static

If we really need to JIT-compile a function that has a condition on the value of an input, we can tell JAX to treat that input as static, so that the function sees its concrete (hashable) Python value instead of an abstract tracer, by specifying static_argnums or static_argnames. The cost of this is that the resulting jaxpr and compiled artifact depend on the particular value passed, and so JAX will have to re-compile the function for every new value of the specified static input. It is only a good strategy if the function is guaranteed to see a limited set of static values.

```python
f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10))
```

```
10
```

```python
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(10, 20))
```

```
30
```

To specify such arguments when using jit as a decorator, a common pattern is to use Python's functools.partial():

```python
from functools import partial

@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

print(g_jit_decorated(10, 20))
```

```
30
```

JIT and caching

Because the first JIT call incurs compilation overhead, understanding how and when jax.jit() caches previous compilations is key to using it effectively.

Suppose we define f = jax.jit(g). When we first invoke f, it will get compiled, and the resulting XLA code will get cached. Subsequent calls to f with arguments of the same shapes and dtypes will reuse the cached code. This is how jax.jit makes up for the up-front cost of compilation.
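One way to observe this caching is to count how often the Python body runs, using the trace-time side effect described earlier (a debugging trick only, not behavior to rely on). In this sketch, double and double_jit stand in for g and f, and the counter trace_count is hypothetical instrumentation. Since the cache is keyed on properties such as argument shapes and dtypes, only the call with a new shape triggers a retrace:

```python
import jax
import jax.numpy as jnp

trace_count = 0

def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces double
    return x * 2

double_jit = jax.jit(double)
double_jit(jnp.ones(3))   # first call: trace and compile
double_jit(jnp.zeros(3))  # same shape and dtype: cached code, no retrace
double_jit(jnp.ones(4))   # new shape: retrace and recompile
print(trace_count)  # 2
```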

If we specifystatic_argnums, then the cached code will be used only for the same values of arguments labelled as static. If any of them change, recompilation occurs.If there are many values, then your program might spend more time compiling than it would have executing ops one-by-one.
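A similar sketch shows recompilation driven by a static argument. Here the hypothetical function add_n records the value of its static argument n each time it is traced; because n is static, it arrives as a plain Python int rather than a tracer:

```python
from functools import partial

import jax

traced_with = []

@partial(jax.jit, static_argnames=["n"])
def add_n(x, n):
    traced_with.append(n)  # runs only when JAX traces add_n for a new n
    return x + n

add_n(1, n=2)  # traces with n=2
add_n(5, n=2)  # same static value: cached code is reused
add_n(1, n=3)  # new static value: retrace and recompile
print(traced_with)  # [2, 3]
```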

Avoid calling jax.jit() on temporary functions defined inside loops or other Python scopes. In most cases, JAX will be able to use the compiled, cached function in subsequent calls to jax.jit(). However, because the cache relies on the hash of the function, it becomes problematic when equivalent functions are redefined. This will cause unnecessary compilation each time in the loop:

```python
from functools import partial

def unjitted_loop_body(prev_i):
    return prev_i + 1

def g_inner_jitted_partial(x, n):
    i = 0
    while i < n:
        # Don't do this! Each time, partial returns
        # a function with a different hash.
        i = jax.jit(partial(unjitted_loop_body))(i)
    return x + i

def g_inner_jitted_lambda(x, n):
    i = 0
    while i < n:
        # Don't do this! The lambda also creates
        # a function with a different hash.
        i = jax.jit(lambda x: unjitted_loop_body(x))(i)
    return x + i

def g_inner_jitted_normal(x, n):
    i = 0
    while i < n:
        # This is OK, since JAX can find the
        # cached, compiled function.
        i = jax.jit(unjitted_loop_body)(i)
    return x + i

print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()

print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()

print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()
```

```
jit called in a loop with partials:
330 ms ± 19.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
316 ms ± 7.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching:
1.45 ms ± 3.23 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
