
Ahead-of-time lowering and compilation

JAX’s jax.jit transformation returns a function that, when called, compiles a computation and runs it on accelerators (or the CPU). As the JIT acronym indicates, all compilation happens just-in-time for execution.

Some situations call for ahead-of-time (AOT) compilation instead. When you want to fully compile prior to execution time, or you want control over when different parts of the compilation process take place, JAX has some options for you.

First, let’s review the stages of compilation. Suppose that f is a function/callable output by jax.jit(), say f = jax.jit(F) for some input callable F. When it is invoked with arguments, say f(x, y) where x and y are arrays, JAX does the following in order:

  1. Stage out a specialized version of the original Python callable F to an internal representation. The specialization reflects a restriction of F to input types inferred from properties of the arguments x and y (usually their shape and element type). JAX carries out this specialization by a process that we call tracing. During tracing, JAX stages the specialization of F to a jaxpr, which is a function in the Jaxpr intermediate language.

  2. Lower this specialized, staged-out computation to the XLA compiler’s input language, StableHLO.

  3. Compile the lowered HLO program to produce an optimized executable for the target device (CPU, GPU, or TPU).

  4. Execute the compiled executable with the arrays x and y as arguments.

JAX’s AOT API gives you direct control over each of these steps, plus some other features along the way. An example:

>>> import jax

>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4

>>> traced = jax.jit(f).trace(x, y)

>>> # Print the specialized, staged-out representation (as Jaxpr IR)
>>> print(traced.jaxpr)
{ lambda ; a:i32[] b:i32[]. let
    c:i32[] = mul 2:i32[] a
    d:i32[] = add c b
  in (d,) }

>>> lowered = traced.lower()

>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<2> : tensor<i32>
    %0 = stablehlo.multiply %c, %arg0 : tensor<i32>
    %1 = stablehlo.add %0, %arg1 : tensor<i32>
    return %1 : tensor<i32>
  }
}

>>> compiled = lowered.compile()

>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()['flops']
2.0

>>> # Execute the compiled function!
>>> compiled(x, y)
Array(10, dtype=int32, weak_type=True)

Note that the lowered objects can be used only in the same process in which they were lowered. For exporting use cases, see the Exporting and serialization APIs.

See the jax.stages documentation for more details on what functionality the lowered and compiled functions provide.
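As a small aside, the function returned by jax.jit also offers a lower method that performs tracing and lowering in a single call. The sketch below compares it with the two-step trace-then-lower form used in this section. It assumes a JAX version recent enough to provide trace, and the names lowered_via_trace and lowered_direct are illustrative only.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return 2 * x + y

x, y = jnp.int32(3), jnp.int32(4)

# Two-step form used in the text: trace, then lower.
lowered_via_trace = jax.jit(f).trace(x, y).lower()

# One-step shortcut: lower() on the jitted function traces internally.
lowered_direct = jax.jit(f).lower(x, y)

# Both are jax.stages.Lowered objects; compiling yields jax.stages.Compiled.
compiled = lowered_direct.compile()
result = compiled(x, y)
print(type(lowered_direct).__name__, type(compiled).__name__, result)
```

The two-step form is useful when you want to inspect the jaxpr before lowering. The shortcut is enough when only the lowered or compiled stage matters.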

All optional arguments to jit (such as static_argnums) are respected in the corresponding tracing, lowering, compilation, and execution.

In the example above, we can replace the arguments to trace with any objects that have shape and dtype attributes:

>>> import jax.numpy as jnp

>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x, y)
Array(10, dtype=int32)

More generally, trace only needs its arguments to structurally supply what JAX must know for specialization and lowering. For typical array arguments like the ones above, this means shape and dtype fields. For static arguments, by contrast, JAX needs actual concrete values (more on this below).
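To illustrate the point about shape and dtype, here is a minimal sketch that compiles a small affine function from abstract specifications alone, and only later calls it with real arrays. The function affine and the chosen shapes are hypothetical and not part of the original example.

```python
import jax
import jax.numpy as jnp

def affine(w, b, x):
    # x: (batch, features), w: (features, out), b: (out,)
    return x @ w + b

# Abstract stand-ins: only shape and dtype, no array data is allocated.
w_spec = jax.ShapeDtypeStruct((4, 2), jnp.float32)
b_spec = jax.ShapeDtypeStruct((2,), jnp.float32)
x_spec = jax.ShapeDtypeStruct((8, 4), jnp.float32)

compiled_affine = jax.jit(affine).trace(w_spec, b_spec, x_spec).lower().compile()

# Later, call with real arrays of matching shape and dtype.
w = jnp.ones((4, 2), jnp.float32)
b = jnp.zeros((2,), jnp.float32)
x = jnp.ones((8, 4), jnp.float32)
out = compiled_affine(w, b, x)
print(out.shape)
```

This pattern can be handy when compilation should happen before any input data exists, for example during program startup.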

Invoking an AOT-compiled function with arguments that are incompatible with its tracing raises an error:

>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_1d, y_1d)
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]

>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_f, y_f)
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[]

Relatedly, AOT-compiled functions cannot be transformed by JAX’s just-in-time transformations such as jax.jit, jax.grad(), and jax.vmap().

Tracing with static arguments

Tracing with static arguments underscores the interaction between options passed to jax.jit, the arguments passed to trace, and the arguments needed to invoke the resulting compiled function. Continuing with our example above:

>>> lowered_with_x = jax.jit(f, static_argnums=0).trace(7, 8).lower()

>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
    %c = stablehlo.constant dense<14> : tensor<i32>
    %0 = stablehlo.add %c, %arg0 : tensor<i32>
    return %0 : tensor<i32>
  }
}

>>> lowered_with_x.compile()(5)
Array(19, dtype=int32, weak_type=True)

Note that trace here takes two arguments as usual, but the subsequent compiled function accepts only the remaining non-static second argument. The static first argument (value 7) is taken as a constant at lowering time and built into the lowered computation, where it is possibly folded in with other constants. In this case, its multiplication by 2 is simplified, resulting in the constant 14.

Although the second argument to trace above can be replaced by a hollow shape/dtype structure, the static first argument must be a concrete value. Otherwise, tracing raises an error:

>>> jax.jit(f, static_argnums=0).trace(i32_scalar, i32_scalar)
Traceback (most recent call last):
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

>>> jax.jit(f, static_argnums=0).trace(10, i32_scalar).lower().compile()(5)
Array(25, dtype=int32)

The results of trace and of lower are not safe to serialize directly for use in a different process. See Exporting and serialization for additional APIs for this purpose.
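Since each static value is built into the lowered program, supporting several static values means compiling one executable per value. A minimal sketch of that idea, reusing f from the running example, is shown below. The dictionary compiled_by_x and the chosen values are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return 2 * x + y

y_spec = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
jit_f = jax.jit(f, static_argnums=0)

# One executable per static value of x; y stays a runtime argument.
compiled_by_x = {
    x_val: jit_f.trace(x_val, y_spec).lower().compile()
    for x_val in (1, 7, 10)
}

# Each executable takes only the non-static argument y.
results = {x_val: int(fn(jnp.int32(5))) for x_val, fn in compiled_by_x.items()}
print(results)  # {1: 7, 7: 19, 10: 25}
```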

AOT-compiled functions cannot be transformed

Compiled functions are specialized to a particular set of argument “types,” such as arrays with a specific shape and element type in our running example. From JAX’s internal point of view, transformations such as jax.vmap() alter the type signature of functions in a way that invalidates the compiled-for type signature. As a policy, JAX simply does not allow compiled functions to take part in transformations. Example:

>>> import numpy as np

>>> def g(x):
...   assert x.shape == (3, 2)
...   return x @ jnp.ones(2)

>>> def make_z(*shape):
...   return jnp.arange(np.prod(shape)).reshape(shape)

>>> z, zs = make_z(3, 2), make_z(4, 3, 2)

>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).trace(z).lower().compile()

>>> jax.vmap(g_jit)(zs)
Array([[ 1.,  5.,  9.],
       [13., 17., 21.],
       [25., 29., 33.],
       [37., 41., 45.]], dtype=float32)

>>> jax.vmap(g_aot)(zs)
Traceback (most recent call last):
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'>

A similar error is raised when g_aot is involved in autodiff (e.g. jax.grad()). For consistency, transformation by jax.jit is disallowed as well, even though jit does not meaningfully modify its argument’s type signature.
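One way to work within this restriction, sketched here as a suggestion and not as an official recommendation, is to apply the transformation first and AOT-compile the transformed function. The example repeats g from above and compiles its vmapped form for the batched input shape. The name g_batched_aot is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

def g(x):
    assert x.shape == (3, 2)
    return x @ jnp.ones(2)

zs = jnp.arange(np.prod((4, 3, 2))).reshape((4, 3, 2))

# Transform first (vmap over the leading axis), then lower and compile AOT.
g_batched_aot = jax.jit(jax.vmap(g)).trace(zs).lower().compile()

out = g_batched_aot(zs)
print(out.shape)
```

The resulting executable is specialized to the batched signature, so it accepts only inputs of shape (4, 3, 2) with the traced dtype.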

Debug information and analyses, when available

In addition to the primary AOT functionality (separate and explicit lowering, compilation, and execution), JAX’s various AOT stages also offer some additional features to help with debugging and gathering compiler feedback.

For instance, as the initial example above shows, lowered functions often offer a text representation. Compiled functions do the same, and also offer cost and memory analyses from the compiler. All of these are provided via methods on the jax.stages.Lowered and jax.stages.Compiled objects (e.g., lowered.as_text() and compiled.cost_analysis() above). You can obtain more debugging information, e.g., source location, by using the debug_info parameter to lowered.as_text().

These methods are meant as an aid for manual inspection and debugging, not as a reliably programmable API. Their availability and output vary by compiler, platform, and runtime. This makes for two important caveats:

  1. If some functionality is unavailable on JAX’s current backend, then the method for it returns something trivial (and False-like). For example, if the compiler underlying JAX does not provide a cost analysis, then compiled.cost_analysis() will be None.

  2. If some functionality is available, there are still very limited guarantees on what the corresponding method provides. The return value is not required to be consistent (in type, structure, or value) across JAX configurations, backends/platforms, versions, or even invocations of the method. JAX cannot guarantee that the output of compiled.cost_analysis() on one day will remain the same on the following day.
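Given the two caveats above, code that reads these analyses is probably best written defensively. The helper below is a hypothetical sketch (flops_or_none is not a JAX API) that tolerates a missing analysis as well as a list-of-dicts return shape, which is an assumption about how the structure may differ between versions.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return 2 * x + y

compiled = jax.jit(f).trace(jnp.int32(3), jnp.int32(4)).lower().compile()

def flops_or_none(compiled_fn):
    """Hypothetical helper: best-effort FLOP estimate, or None if unavailable."""
    try:
        analysis = compiled_fn.cost_analysis()
    except Exception:
        return None  # treat a backend failure as "no analysis available"
    # Assumption: the result may be a dict or a list of per-computation dicts.
    if isinstance(analysis, (list, tuple)):
        analysis = analysis[0] if analysis else None
    if not analysis:
        return None
    return analysis.get('flops')

flops = flops_or_none(compiled)
print(flops)  # value, if any, depends on backend and JAX version
```

Treating the result as optional keeps the surrounding program working on backends where no analysis is available.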

When in doubt, see the package API documentation for jax.stages.

