Handling of closed-over constants#

“Closed-over constants” are non-scalar arrays that are encountered during JAX tracing of a function and do not have dependencies on any of the function’s arguments. JAX operations, such as those in jax.numpy and lax, are staged out and do not create closed-over constants. In the following example, the arrays a_jax_array and np.full are closed-over constants, but jnp.full is not. We refer below to closed-over constants simply as constants.

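    import numpy as np
    from jax import jit
    from jax import numpy as jnp

    # A jax.Array created outside of f: a closed-over constant.
    a_jax_array = jnp.ones((16,), dtype=np.float32)

    @jit
    def f(x):
      # np.full is a closed-over constant; jnp.full is staged out.
      return x + a_jax_array + np.full((16,), 42.) + jnp.full((16,), 142.)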

We describe below the future internal implementation details for constants. As of July 2025, this is not yet the default implementation; it is enabled by the environment variable JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=True. See further below for the details of the previous implementation, including its drawbacks.

Tracing#

When JAX tracing encounters a constant that is either an argument of a JAX primitive or a function return, it is represented as a core.Literal, and is embedded in the Jaxpr along with the primitives that use it. The function core.is_literalable decides which constants are turned into core.Literal. All scalar constants are turned into core.Literal, along with non-scalar np.ndarray and jax.Array.

Lowering#

When lowering the code to HLO we could just emit a stablehlo.constant operation for a core.Literal, but this would have several disadvantages:

  • if the constant is a jax.Array (e.g., the a_jax_array above), then it is pulled from the device to the host during lowering, and it will later be re-materialized on the device when the lowered module executes. This can increase the host memory usage, sometimes dramatically. Furthermore, if the constant is sharded on multiple devices this sharding is lost.

  • large constants increase the size of the HLO, especially if the same constant is used multiple times. Also, the XLA compiler will attempt to constant-fold them, resulting in warnings and slow compilation. Furthermore, we have observed that XLA constant-folding sometimes produces slightly different numerics compared to compiled code. See also "Large closed-over constants are inlined in the HLO code" (#29684).

Instead, during lowering we use the function core.jaxpr_const_args to scan a Jaxpr and return a list of the constants contained within, uniquified by their id. The result of core.jaxpr_const_args is memoized for each Jaxpr and sub-Jaxpr on which it is called.
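To illustrate what "uniquified by their id" means, here is a minimal pure-Python sketch. The helper unique_by_id is hypothetical and is not the JAX implementation; it only shows that deduplication is by object identity, so two arrays with equal values but different identities are both kept.

```python
import numpy as np

def unique_by_id(consts):
  """Hypothetical helper: keep the first occurrence of each object, keyed by id()."""
  seen = set()
  unique = []
  for c in consts:
    if id(c) not in seen:
      seen.add(id(c))
      unique.append(c)
  return unique

a = np.ones((4,))
b = np.ones((4,))  # equal values, but a distinct object
unique = unique_by_id([a, b, a])
print(len(unique))  # 2: `a` is kept once, `b` is kept because it is a different object
```

Keying on identity avoids comparing array contents and works for objects that are not hashable.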

All the lowered HLO functions will take one additional argument for each unique constant appearing in the Jaxpr to which it corresponds. These arguments, referred to as const_args, come after the dimension variable arguments, after the token arguments, and just before the actual array arguments. During lowering we maintain a mapping const_lowering: dict[int, mlir.IrValues] from the id of the constants to the HLO values for the corresponding const args. This mapping is stored in the mlir.LoweringRuleContext and is used by mlir.ir_constant: when a constant is encountered, we just reuse the existing lowering from const_lowering instead of emitting a stablehlo.constant.
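The argument order and the role of const_lowering can be sketched with a toy model. Everything below is hypothetical: the helper names are invented for illustration, and integer argument positions stand in for the HLO values (mlir.IrValues) that the real mapping holds.

```python
import numpy as np

def flatten_function_args(dim_vars, tokens, const_args, array_args):
  """Hypothetical helper mirroring the argument order described above."""
  return [*dim_vars, *tokens, *const_args, *array_args]

def build_const_lowering(dim_vars, tokens, const_args):
  """Map id(constant) -> position of its argument in the flattened list."""
  offset = len(dim_vars) + len(tokens)
  return {id(c): offset + i for i, c in enumerate(const_args)}

def lower_constant(c, const_lowering):
  """Reuse the hoisted argument if there is one, else fall back to inlining."""
  if id(c) in const_lowering:
    return ("arg", const_lowering[id(c)])
  return ("inline_constant", c)

big = np.arange(1024.0)   # hoisted as a const arg
other = np.zeros((3,))    # not among the const args
flat = flatten_function_args(["dim_b"], ["token0"], [big], ["x", "y"])
const_lowering = build_const_lowering(["dim_b"], ["token0"], [big])

print(lower_constant(big, const_lowering))       # ('arg', 2)
print(lower_constant(other, const_lowering)[0])  # inline_constant
```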

When we lower an HLO inner function (i.e., not the main function), we call core.jaxpr_const_args again to get the actual constants in the corresponding Jaxpr. These are expected to be among the constants for which we have a const_lowering. The inner function will get its own smaller set of const_args and its own const_lowering mapping to be used when lowering the body. E.g., the function mlir.lower_jaxpr_as_fun is one place where some of this happens.

The function mlir.jaxpr_subcomp does not create a new HLO function, but instead creates a block within the current function. It uses the enclosing function’s const_lowering.

Note also that there will still be stablehlo.constant operations in the lowered code, in three cases:

  • when the constant is a scalar; we want these constants to be available to XLA for constant folding.

  • when the constant did not appear in the traced program, and is hence not in the Jaxpr. This can happen for constants that arise during lowering, e.g., the lowering of some PRNG functions includes constants.

  • when we are exporting: at the moment, we do not hoist constant args when we export because the export serialization does not currently support serialization of arrays. We use the mlir.LoweringParameters.hoist_constants_as_args parameter to control this.
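One way to observe this from user code is to lower a jitted function and count the stablehlo.constant operations in the module text. The sketch below uses only the public lower(...).as_text() API of a jitted function. The count it prints depends on the JAX version and on whether JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS is set, so no particular value is assumed here.

```python
import numpy as np
import jax
from jax import numpy as jnp

a_jax_array = jnp.ones((16,), dtype=np.float32)

@jax.jit
def f(x):
  return x + a_jax_array + np.full((16,), 42.) + jnp.full((16,), 142.)

x = jnp.zeros((16,), dtype=np.float32)

# Lower without compiling, and get the lowered module as text.
hlo_text = f.lower(x).as_text()
num_constant_ops = hlo_text.count("stablehlo.constant")
print(f"stablehlo.constant ops in the lowered module: {num_constant_ops}")
```

Running this with and without the environment variable set is a simple way to compare the two implementations on the same function.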

One additional complication is that some of the internal lowering functions need to take the argument avals and sometimes also the shardings and layouts for the arguments. Furthermore, the avals, shardings, and layouts for all arguments, including the const args, are also used after lowering. Therefore, it is convenient to compute these fairly high in the call stack, e.g., in pxla.lower_sharding_computations, and pass them down.

For example, the functions mlir.lower_jaxpr_to_module, pjit._pjit_cached_lower_jaxpr_to_fun, and mlir.lower_jaxpr_to_fun take in_avals, in_shardings, and in_layouts that include both the avals for const_args and for the regular args (the ones corresponding to the Jaxpr.invars). They also take a num_const_args argument.

Compilation and execution#

The lowered MLIR module contains arguments for the const args, so the compiled executable will need to be passed the const args. It is important to choose the right place where we prepend the const args. For example, in the following code, the second invocation of the jitted function f is expected to hit the C++ jit cache without any Python code executing.

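    import jax
    from jax import numpy as jnp

    const = jnp.array([42.])
    f = jax.jit(lambda: const)
    f()  # first call: traces, lowers, and compiles
    f()  # second call: expected to hit the C++ jit cache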

(TODO: yashk2810 plans to write a description of how the jit caches work.)

This means that the const will have to be passed to the executable in C++ (and thus stored in pxla.MeshExecutableFastpathData), and therefore the C++ cache miss functions (e.g., pjit._cpp_pjit.cache_miss, or aot_cache_miss in pxla.MeshExecutable.create_cpp_call) will not take the const args as arguments. Instead these cache miss functions will have to prepend the const args.

The C++ fast path has support for const args starting with jaxlib 0.7.1. In prior versions, the fast path is disabled when there are const args.
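The idea of prepending the const args at call time can be sketched in plain Python. This is a toy model only: make_call and toy_executable are hypothetical names, and the real mechanism lives in the C++ fast path and the cache miss functions mentioned above.

```python
import numpy as np

def make_call(executable, const_args):
  """Hypothetical wrapper: stores the const args and prepends them on each call."""
  def call(*args):
    return executable(*const_args, *args)
  return call

# Stand-in for a compiled executable that was lowered with one hoisted constant:
# its first parameter is the const arg, the remaining ones are the user arguments.
def toy_executable(const, x):
  return const + x

const = np.array([42.0])
f = make_call(toy_executable, [const])
print(f(np.array([1.0])))  # [43.]
```

The caller of f never sees the constant; only the wrapper knows that the executable expects it as a leading argument.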

To implement this scheme, we keep the const_args in stages.Lowering, stages.Lowered, and stages.CompiledCallParams.

Interestingly, when we serialize an executable, e.g., for the compilation cache, we do not need to serialize the closed-over constants. The executable itself does not contain them, and needs to take them as const args. Whoever deserializes the cached executable will have to pass the const args.

In AOT mode, the lowering and execution may use different values of the jax_enable_x64 configuration value. If the constants are 64-bit ndarray we must use the same value of jax_enable_x64 for lowering and execution.
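The sensitivity to jax_enable_x64 presumably comes from dtype canonicalization: when the flag is off, a 64-bit ndarray is converted to 32 bits when it enters JAX. The small check below, using only public APIs, shows the dtype a 64-bit constant takes under whichever setting is currently active; it does not change the setting.

```python
import numpy as np
import jax
from jax import numpy as jnp

const64 = np.arange(4, dtype=np.float64)

# Read the active setting; this sketch does not modify it.
x64_enabled = bool(jax.config.jax_enable_x64)

as_jax = jnp.asarray(const64)
expected_dtype = np.float64 if x64_enabled else np.float32
print("jax_enable_x64:", x64_enabled, "dtype seen by JAX:", as_jax.dtype)
```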

Previous implementation#

This describes the current way we handle closed-over constants, as of July 2025 (as long as JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=False).

When JAX traces a function to a Jaxpr it collects the closed-over values into a set of constants, and adds a corresponding set of constvars to the Jaxpr (the actual arguments are represented by invars). Most tracing functions, e.g., trace_to_jaxpr_dynamic, return both the Jaxpr and the constants.

In many places in the code we use a class core.ClosedJaxpr that contains a Jaxpr and consts corresponding to the Jaxpr.constvars.
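The relationship between invars, constvars, and consts can be inspected with the public jax.make_jaxpr, which returns a ClosedJaxpr. This is only an inspection sketch: how many constvars appear depends on which of the two implementations is active, so the code prints them instead of assuming a number.

```python
import numpy as np
import jax
from jax import numpy as jnp

a_np_array = np.arange(16, dtype=np.float32)

def f(x):
  return x + a_np_array  # a_np_array is closed over

closed_jaxpr = jax.make_jaxpr(f)(jnp.zeros((16,), dtype=np.float32))
jaxpr = closed_jaxpr.jaxpr

print(type(closed_jaxpr).__name__)
print("invars:", jaxpr.invars)        # the actual arguments
print("constvars:", jaxpr.constvars)  # variables bound to closed-over values
print("number of consts:", len(closed_jaxpr.consts))
```

In either implementation, consts and constvars are expected to line up one to one.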

There are several issues with ClosedJaxpr:

  • the lowering of the consts in ClosedJaxpr results in inlined stablehlo.constant, with all the issues described above.

  • Jaxpr and ClosedJaxpr are used pervasively in JAX, often with the generic name jaxpr, and it is not easy to tell which kind of Jaxpr we have. We have started to add type declarations, but in some places the code is written with isinstance conditionals to work with both.

  • Since Jaxpr and ClosedJaxpr are sometimes used as caching keys, and they are hashed by id, we would like to memoize their construction. For example, the function pe.closed_jaxpr memoizes the construction of ClosedJaxpr but only for the case when consts is empty. This is because sometimes consts are not hashable.

  • Handling the constants in ClosedJaxpr requires some extra care. E.g., there are places in the Mosaic lowering where we have not yet implemented the handling of ClosedJaxpr with non-empty constants.

  • When we turn closed-over constants into inputs we have to be careful during transformations with how we handle these auxiliary inputs.
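The memoization issue in the list above can be made concrete with a toy model. ToyClosedJaxpr and toy_closed_jaxpr are hypothetical stand-ins, not the JAX classes or pe.closed_jaxpr; the sketch only shows why construction can be cached when consts is empty and why unhashable consts get in the way otherwise.

```python
import numpy as np

class ToyClosedJaxpr:
  """Hypothetical stand-in for core.ClosedJaxpr; hashed by id (the Python default)."""
  def __init__(self, jaxpr, consts):
    self.jaxpr = jaxpr
    self.consts = list(consts)

_memo = {}

def toy_closed_jaxpr(jaxpr, consts=()):
  """Memoize construction only when there are no consts."""
  if len(consts) > 0:
    # Arrays cannot be part of a dict key, so build a fresh object every time.
    return ToyClosedJaxpr(jaxpr, consts)
  if jaxpr not in _memo:
    _memo[jaxpr] = ToyClosedJaxpr(jaxpr, ())
  return _memo[jaxpr]

toy_jaxpr = object()  # stand-in for a Jaxpr, hashable by id
c = np.ones((2,))

same = toy_closed_jaxpr(toy_jaxpr) is toy_closed_jaxpr(toy_jaxpr)
different = toy_closed_jaxpr(toy_jaxpr, [c]) is not toy_closed_jaxpr(toy_jaxpr, [c])
print(same, different)  # True True

try:
  hash(c)
  ndarray_hashable = True
except TypeError:
  ndarray_hashable = False
print(ndarray_hashable)  # False
```

Because the objects are hashed by id, the two un-memoized results above would be distinct cache keys even though they describe the same computation.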

