
Exporting and serializing staged-out computations#

The Ahead-of-time lowering and compilation APIs produce objects that can be used for debugging or for compilation and execution in the same process. Sometimes you want to serialize a lowered JAX function for compilation and execution in a separate process, perhaps at a later time. This would allow you to:

  • compile and execute the function in another process or machine without requiring access to the JAX program, and without having to repeat the staging-out and lowering, e.g., in an inference system.

  • trace and lower a function on a machine that does not have access to the accelerator for which you want to later compile and execute the function.

  • archive a snapshot of a JAX function, e.g., to be able to reproduce your results later. Note: check out the compatibility guarantees for this use case.

For more details see the jax.export API reference.

Here is an example:

>>> import re
>>> import numpy as np
>>> import jax
>>> from jax import export

>>> def f(x): return 2 * x * x

>>> exported: export.Exported = export.export(jax.jit(f))(
...    jax.ShapeDtypeStruct((), np.float32))

>>> # You can inspect the Exported object
>>> exported.fun_name
'f'
>>> exported.in_avals
(ShapedArray(float32[]),)
>>> print(re.search(r".*@main.*", exported.mlir_module()).group(0))
  func.func public @main(%arg0: tensor<f32> loc("x")) -> (tensor<f32> {jax.result_info = "result"}) {

>>> # And you can serialize the Exported to a bytearray.
>>> serialized: bytearray = exported.serialize()

>>> # The serialized function can later be rehydrated and called from
>>> # another JAX computation, possibly in another process.
>>> rehydrated_exp: export.Exported = export.deserialize(serialized)
>>> rehydrated_exp.in_avals
(ShapedArray(float32[]),)
>>> def callee(y):
...  return 3. * rehydrated_exp.call(y * 4.)
>>> callee(1.)
Array(96., dtype=float32)

Serialization is broken down into two stages:

  1. exporting to produce a jax.export.Exported object that contains the StableHLO for the lowered function along with the metadata necessary to call it from another JAX function. We have plans to add code to generate Exported objects from TensorFlow, and to use Exported objects from TensorFlow and PyTorch.

  2. the actual serialization to a byte array using the flatbuffers format. See Interoperation with TensorFlow for an alternative serialization to a TensorFlow graph that can be used for interoperation with TensorFlow.

Support for reverse-mode AD#

Serialization can optionally support higher-order reverse-mode AD. This is done by serializing the jax.vjp() of the primal function along with the primal function, up to a user-specified order (default is 0, meaning that the rehydrated function cannot be differentiated):

>>> import jax
>>> from jax import export
>>> from typing import Callable

>>> def f(x): return 7 * x * x * x

>>> # Serialize 3 levels of VJP along with the primal function
>>> blob: bytearray = export.export(jax.jit(f))(1.).serialize(vjp_order=3)
>>> rehydrated_f: Callable = export.deserialize(blob).call

>>> rehydrated_f(0.1)  # 7 * 0.1^3
Array(0.007, dtype=float32)
>>> jax.grad(rehydrated_f)(0.1)  # 7*3 * 0.1^2
Array(0.21000001, dtype=float32)
>>> jax.grad(jax.grad(rehydrated_f))(0.1)  # 7*3*2 * 0.1
Array(4.2, dtype=float32)
>>> jax.grad(jax.grad(jax.grad(rehydrated_f)))(0.1)  # 7*3*2
Array(42., dtype=float32)
>>> jax.grad(jax.grad(jax.grad(jax.grad(rehydrated_f))))(0.1)
Traceback (most recent call last):
ValueError: No VJP is available

Note that the VJP function is computed lazily while serializing, when the JAX program is still available. This means that it respects all features of JAX VJP, e.g., jax.custom_vjp() and jax.remat().

Note that the rehydrated function does not support any other transformations, e.g., forward-mode AD (jvp) or jax.vmap().

Compatibility guarantees#

You should not use the raw StableHLO obtained from just lowering (jax.jit(f).lower(1.).compiler_ir()) for archival or for compilation in another process, for several reasons.

First, the compilation may use a different version of the compiler, supporting a different version of StableHLO. The jax.export module takes care of this by using the portable-artifact feature of StableHLO to deal with the possible evolution of the StableHLO opset.

Compatibility guarantees for custom calls#

Second, the raw StableHLO may contain custom calls referencing C++ functions. JAX uses custom calls for lowering a small number of primitives, e.g., linear algebra primitives, sharding annotations, or Pallas kernels. These do not fall under the compatibility guarantees for StableHLO. The C++ implementations of these functions change rarely, but they can change.

jax.export makes the following export compatibility guarantees: a JAX exported artifact can be compiled and executed by a compiler and JAX runtime system that are:

  • up to 6 months newer than the version of JAX used for exporting (we say that JAX export offers 6 months backward compatibility). This is useful if we want to archive the exported artifact to be compiled and executed later.

  • up to 3 weeks older than the version of JAX used for exporting (we say that JAX export offers 3 weeks forward compatibility). This is useful if we want to compile and run an exported artifact with a consumer that was built and deployed before the export, e.g., an inference system that is already deployed when the exporting is done.

(The particular compatibility window lengths are the same that JAX promised for jax2tf, and are based on TensorFlow Compatibility. The terminology “backward compatibility” is from the perspective of the consumer, e.g., the inference system.)

What matters is when the exporting and consuming components were built, not when the exporting and the compilation happen. For external JAX users, it is possible to run JAX and jaxlib at different versions; what matters is when the jaxlib release was built.

To reduce chances of incompatibility, internal JAX users should:

  • rebuild and redeploy consumer systems as frequently as possible.

and external users should:

  • run the exporting and consumer systems with the same version of jaxlib, whenever possible, and

  • export for archival with the latest released version of jaxlib.

The compatibility guarantees do not apply if you bypass the jax.export APIs to obtain the StableHLO code.

To ensure forward compatibility, when we change the JAX lowering rules to use a new custom call target, JAX will refrain from using the new target for 3 weeks. To use the latest lowering rules, you can pass the --jax_export_ignore_forward_compatibility=1 configuration flag or set the JAX_EXPORT_IGNORE_FORWARD_COMPATIBILITY=1 environment variable.

Only a subset of custom calls are guaranteed stable and have compatibility guarantees (see list). We continuously add more custom call targets to the allowed list, along with backwards compatibility tests. If you try to serialize code that invokes other custom call targets you will get an error during exporting.

If you want to disable this safety check for a specific custom call, e.g., with target my_target, you can add export.DisabledSafetyCheck.custom_call("my_target") to the disabled_checks parameter of the export method, as in the following example:

>>> import jax
>>> from jax import export
>>> from jax import lax
>>> from jax._src import core
>>> from jax._src.interpreters import mlir

>>> # Define a new primitive backed by a custom call
>>> new_prim = core.Primitive("new_prim")
>>> _ = new_prim.def_abstract_eval(lambda x: x)
>>> _ = mlir.register_lowering(new_prim,
...                            lambda ctx, o: mlir.custom_call("my_new_prim",
...                                                            operands=[o],
...                                                            result_types=[o.type]).results)
>>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir())
module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = "result"}) {
    %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
}

>>> # If we try to export, we get an error
>>> export.export(jax.jit(new_prim.bind))(1.)
Traceback (most recent call last):
ValueError: Cannot serialize code with custom calls whose targets have no compatibility guarantees: my_new_prim

>>> # We can avoid the error if we pass a `DisabledSafetyCheck.custom_call`
>>> exp = export.export(
...    jax.jit(new_prim.bind),
...    disabled_checks=[export.DisabledSafetyCheck.custom_call("my_new_prim")])(1.)

SeeEnsuring forward and backward compatibility for developer information regardingensuring compatibility.

Cross-platform and multi-platform export#

JAX lowering is platform specific for a small number of JAX primitives. By default, the code is lowered and exported for the accelerator present on the exporting machine:

>>> from jax import export
>>> export.default_export_platform()
'cpu'

There is a safety check that will raise an error when trying to compile an Exported object on a machine that does not have the accelerator for which the code was exported.

You can specify explicitly for which platforms the code should be exported. This allows you to specify a different accelerator than the one available at export time, and it even allows you to specify multi-platform export to obtain an Exported object that can be compiled and executed on multiple platforms.

>>> import jax
>>> from jax import export
>>> from jax import lax

>>> # You can specify the export platform, e.g., `tpu`, `cpu`, `cuda`, `rocm`
>>> # even if the current machine does not have that accelerator.
>>> exp = export.export(jax.jit(lax.cos), platforms=['tpu'])(1.)

>>> # But you will get an error if you try to compile `exp`
>>> # on a machine that does not have TPUs.
>>> exp.call(1.)
Traceback (most recent call last):
ValueError: Function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'.

>>> # We can avoid the error if we pass a `DisabledSafetyCheck.platform`
>>> # parameter to `export`, e.g., because you have reasons to believe
>>> # that the code lowered will run adequately on the current
>>> # compilation platform (which is the case for `cos` in this
>>> # example):
>>> exp_unsafe = export.export(jax.jit(lax.cos),
...    platforms=['tpu'],
...    disabled_checks=[export.DisabledSafetyCheck.platform()])(1.)

>>> exp_unsafe.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)

>>> # and similarly with multi-platform lowering
>>> exp_multi = export.export(jax.jit(lax.cos),
...    platforms=['tpu', 'cpu', 'cuda'])(1.)
>>> exp_multi.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)

For multi-platform export, the StableHLO will contain multiple lowerings, but only for those primitives that require it, so the resulting module should be only marginally larger than a module with default export. As an extreme case, when serializing a module without any primitives with platform-specific lowering, you will get the same StableHLO as for the single-platform export.

>>> import jax
>>> from jax import export
>>> from jax import lax

>>> # A largish function
>>> def f(x):
...   for i in range(1000):
...     x = jnp.cos(x)
...   return x

>>> exp_single = export.export(jax.jit(f))(1.)
>>> len(exp_single.mlir_module_serialized)
9220

>>> exp_multi = export.export(jax.jit(f),
...    platforms=["cpu", "tpu", "cuda"])(1.)
>>> len(exp_multi.mlir_module_serialized)
9282

Shape polymorphic export#

When used in JIT mode, JAX will trace and lower a function separately for each combination of input shapes. When exporting, it is possible in some cases to use dimension variables for some input dimensions in order to obtain an exported artifact that can be used with multiple combinations of input shapes.

See the Shape polymorphism documentation.
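As a brief sketch of the idea (see the linked documentation for the full API; the function f here is illustrative), a dimension variable introduced with jax.export.symbolic_shape lets one exported artifact serve multiple input shapes:

```python
import numpy as np
import jax
from jax import export

def f(x):
    # An elementwise function; its lowering is shape-polymorphic.
    return x * 2.0

# Use a symbolic dimension "b" for the leading axis, so one exported
# artifact works for any batch size.
poly_shape = export.symbolic_shape("b, 4")
exp = export.export(jax.jit(f))(
    jax.ShapeDtypeStruct(poly_shape, np.float32))

# The rehydrated function accepts several concrete shapes.
rehydrated = export.deserialize(exp.serialize())
print(rehydrated.call(np.ones((3, 4), np.float32)).shape)   # (3, 4)
print(rehydrated.call(np.ones((10, 4), np.float32)).shape)  # (10, 4)
```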

Device-polymorphic export#

An exported artifact may contain sharding annotations for inputs, outputs, and some intermediates, but these annotations do not refer directly to the actual physical devices that existed at export time. Instead, the sharding annotations refer to logical devices. This means that you can compile and run the exported artifact on physical devices different from those used for exporting.

The cleanest way to achieve a device-polymorphic export is to use shardings constructed with a jax.sharding.AbstractMesh, which contains only the mesh shape and axis names. But you can achieve the same results if you use shardings constructed for a mesh with concrete devices, since the actual devices in the mesh are ignored for tracing and lowering:

>>> import jax
>>> from jax.sharding import AbstractMesh, Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>> from jax import export

>>> # Use an AbstractMesh for exporting
>>> export_mesh = AbstractMesh((4,), ("a",))
>>> def f(x):
...   return x.T

>>> exp = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((32,), dtype=np.int32,
...                          sharding=NamedSharding(export_mesh, P("a"))))

>>> # `exp` knows for how many devices it was exported.
>>> exp.nr_devices
4

>>> # and it knows the shardings for the inputs. These will be applied
>>> # when the exported is called.
>>> exp.in_shardings_hlo
({devices=[4]<=[4]},)

>>> # You can also use a concrete set of devices for exporting
>>> concrete_devices = jax.local_devices()[:4]
>>> concrete_mesh = Mesh(concrete_devices, ("a",))
>>> exp2 = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((32,), dtype=np.int32,
...                          sharding=NamedSharding(concrete_mesh, P("a"))))

>>> # You can expect the same results
>>> assert exp.in_shardings_hlo == exp2.in_shardings_hlo

>>> # When you call an Exported, you must use a concrete set of devices
>>> arg = jnp.arange(8 * 4)
>>> res1 = exp.call(jax.device_put(arg,
...                                NamedSharding(concrete_mesh, P("a"))))

>>> # Check out the first 2 shards of the result
>>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]]
['device=TFRT_CPU_0 index=(slice(0, 8, None),)', 'device=TFRT_CPU_1 index=(slice(8, 16, None),)']

>>> # We can call `exp` with some other 4 devices and another
>>> # mesh with a different shape, as long as the number of devices is
>>> # the same.
>>> other_mesh = Mesh(np.array(jax.local_devices()[2:6]).reshape((2, 2)), ("b", "c"))
>>> res2 = exp.call(jax.device_put(arg,
...                                NamedSharding(other_mesh, P("b"))))

>>> # Check out the first 2 shards of the result. Notice that the output is
>>> # sharded similarly; this means that the input was resharded according to the
>>> # exp.in_shardings.
>>> [f"device={s.device} index={s.index}" for s in res2.addressable_shards[:2]]
['device=TFRT_CPU_2 index=(slice(0, 8, None),)', 'device=TFRT_CPU_3 index=(slice(8, 16, None),)']

It is an error to try to invoke an exported artifact with a different number of devices than it was exported for:

>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
...   return x.T

>>> exp = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
...                          sharding=NamedSharding(export_mesh, P("a"))))

>>> arg = jnp.arange(4 * len(export_devices))
>>> exp.call(arg)
Traceback (most recent call last):
ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device.

There are helper functions to shard the inputs for calling an exported artifact using a new mesh constructed at the call site:

>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
...   return x.T

>>> exp = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
...                          sharding=NamedSharding(export_mesh, P("a"))))

>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("a",))

>>> # Shard the arg according to what `exp` expects.
>>> arg = jnp.arange(4 * len(export_devices))
>>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
>>> res = exp.call(sharded_arg)

As a special facility, if a function was exported for 1 device and contains no sharding annotations, then it can be invoked on an argument of the same shape but sharded on multiple devices, and the compiler will shard the function appropriately:

>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> def f(x):
...   return jnp.cos(x)

>>> arg = jnp.arange(4)
>>> exp = export.export(jax.jit(f))(arg)
>>> exp.in_avals
(ShapedArray(int32[4]),)
>>> exp.nr_devices
1

>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(jax.local_devices()[:4], ("b",))

>>> # Shard the arg according to what `exp` expects.
>>> sharded_arg = jax.device_put(arg,
...                              NamedSharding(calling_mesh, P("b")))
>>> res = exp.call(sharded_arg)

Calling convention versions#

The JAX export support has evolved over time, e.g., to support effects. In order to support compatibility (see compatibility guarantees) we maintain a calling convention version for each Exported. As of July 2025, all functions are exported with version 10 (the latest; see all calling convention versions):

>>> from jax import export
>>> exp: export.Exported = export.export(jnp.cos)(1.)
>>> exp.calling_convention_version
10

At any given time, the export APIs may support a range of calling convention versions. You can control which calling convention version to use with the --jax_export_calling_convention_version flag or the JAX_EXPORT_CALLING_CONVENTION_VERSION environment variable:

>>> from jax import export
>>> (export.minimum_supported_calling_convention_version,
...  export.maximum_supported_calling_convention_version)
(9, 10)

>>> from jax._src import config
>>> with config.jax_export_calling_convention_version(10):
...   exp = export.export(jnp.cos)(1.)
...   exp.calling_convention_version
10

We reserve the right to remove support for generating or consuming calling convention versions older than 6 months.

Module calling convention#

The Exported.mlir_module has a main function that takes an optional first platform index argument if the module supports multiple platforms (len(platforms) > 1), followed by the token arguments corresponding to the ordered effects, followed by the kept array arguments (corresponding to module_kept_var_idx and in_avals). The platform index is an i32 or i64 scalar encoding the index of the current compilation platform into the platforms sequence.

Inner functions use a different calling convention: an optional platform index argument, optional dimension variable arguments (scalar tensors of type i32 or i64), followed by optional token arguments (in the presence of ordered effects), followed by the regular array arguments. The dimension arguments correspond to the dimension variables appearing in args_avals, in sorted order of their names.

Consider the lowering of a function with one array argument of type f32[w, 2*h], where w and h are two dimension variables. Assume that we use multi-platform lowering and that we have one ordered effect. The main function will be as follows:

      func public main(
            platform_index: i32 {jax.global_constant="_platform_index"},
            token_in: token,
            arg: f32[?, ?]) {
         arg_w = hlo.get_dimension_size(arg, 0)
         dim1 = hlo.get_dimension_size(arg, 1)
         arg_h = hlo.floordiv(dim1, 2)
         call _check_shape_assertions(arg)  # See below
         token = new_token()
         token_out, res = call _wrapped_jax_export_main(platform_index,
                                                        arg_h,
                                                        arg_w,
                                                        token_in,
                                                        arg)
         return token_out, res
      }

The actual computation is in _wrapped_jax_export_main, which also takes the values of the h and w dimension variables.

The signature of _wrapped_jax_export_main is:

      func private _wrapped_jax_export_main(
          platform_index: i32 {jax.global_constant="_platform_index"},
          arg_h: i32 {jax.global_constant="h"},
          arg_w: i32 {jax.global_constant="w"},
          arg_token: stablehlo.token {jax.token=True},
          arg: f32[?, ?]) -> (stablehlo.token, ...)

Prior to calling convention version 9, the calling convention for effects was different: the main function did not take or return a token. Instead the function created dummy tokens of type i1[0] and passed them to _wrapped_jax_export_main. _wrapped_jax_export_main took the dummy tokens of type i1[0] and internally created real tokens to pass to the inner functions. The inner functions use real tokens (both before and after calling convention version 9).

Also starting with calling convention version 9, function arguments that contain the platform index or the dimension variable values have a jax.global_constant string attribute whose value is the name of the global constant, either _platform_index or a dimension variable name. The global constant name may be empty if it is not known. Some global constant computations use inner functions, e.g., for floor_divide. The arguments of such functions have a jax.global_constant attribute for all arguments, meaning that the result of the function is also a global constant.

Note that main contains a call to _check_shape_assertions. JAX tracing assumes that arg.shape[1] is even and that both w and h have values >= 1. We must check these constraints when we invoke the module. We use a special custom call @shape_assertion that takes a boolean first operand, a string error_message attribute that may contain format specifiers {0}, {1}, …, and a variadic number of integer scalar operands corresponding to the format specifiers.

       func private _check_shape_assertions(arg: f32[?, ?]) {
         # Check that w is >= 1
         arg_w = hlo.get_dimension_size(arg, 0)
         custom_call @shape_assertion(arg_w >= 1, arg_w,
            error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
         # Check that dim1 is even
         dim1 = hlo.get_dimension_size(arg, 1)
         custom_call @shape_assertion(dim1 % 2 == 0, dim1 % 2,
            error_message="Division had remainder {0} when computing the value of 'h'")
         # Check that h >= 1
         arg_h = hlo.floordiv(dim1, 2)
         custom_call @shape_assertion(arg_h >= 1, arg_h,
            error_message="Dimension variable 'h' must have integer value >= 1. Found {0}")
       }

Calling convention versions#

We list here a history of the calling convention version numbers:

  • Version 1 used MHLO & CHLO to serialize the code. Not supported anymore.

  • Version 2 supports StableHLO & CHLO. Used from October 2022. Not supported anymore.

  • Version 3 supports platform checking and multiple platforms. Used from February 2023. Not supported anymore.

  • Version 4 supports StableHLO with compatibility guarantees. This is the earliest version at the time of the JAX native serialization launch. Used in JAX from March 15, 2023 (cl/516885716). Starting with March 28th, 2023 we stopped using dim_args_spec (cl/520033493). The support for this version was dropped on October 17th, 2023 (cl/573858283).

  • Version 5 adds support for call_tf_graph. This is currently used for some specialized use cases. Used in JAX from May 3rd, 2023 (cl/529106145).

  • Version 6 adds support for the disabled_checks attribute. This version mandates a non-empty platforms attribute. Supported by XlaCallModule since June 7th, 2023 and available in JAX since June 13th, 2023 (JAX 0.4.13).

  • Version 7 adds support for stablehlo.shape_assertion operations and for shape_assertions specified in disabled_checks. See Errors in presence of shape polymorphism. Supported by XlaCallModule since July 12th, 2023 (cl/547482522), available in JAX serialization since July 20th, 2023 (JAX 0.4.14), and the default since August 12th, 2023 (JAX 0.4.15).

  • Version 8 adds support for the jax.uses_shape_polymorphism module attribute and enables the shape refinement pass only when the attribute is present. Supported by XlaCallModule since July 21st, 2023 (cl/549973693), available in JAX since July 26th, 2023 (JAX 0.4.14), and the default since October 21st, 2023 (JAX 0.4.20).

  • Version 9 adds support for effects. See the docstring for export.Exported for the precise calling convention. In this calling convention version we also tag the platform index and the dimension variables arguments with jax.global_constant attributes. Supported by XlaCallModule since October 27th, 2023, available in JAX since October 20th, 2023 (JAX 0.4.20), and the default since February 1st, 2024 (JAX 0.4.24). This was the only supported version as of March 27th, 2024.

  • Version 10 propagates the jax.config.use_shardy_partitioner value to XlaCallModule. Supported by XlaCallModule since May 20th, 2025, and the default in JAX since July 14th, 2025 (JAX 0.7.0).

Developer documentation#

Debugging#

You can log the exported modules, with somewhat different flags in OSS versus in Google. In OSS you can do the following:

# Log from python
python tests/export_test.py JaxExportTest.test_basic -v=3

# Or, log from pytest to /tmp/mylog.txt
pytest tests/export_test.py -k test_basic --log-level=3 --log-file=/tmp/mylog.txt

You will see a log line of the form:

I0619 10:54:18.978733 8299482112 _export.py:606] Exported JAX function: fun_name=sin version=9 lowering_platforms=('cpu',) disabled_checks=()
I0619 10:54:18.978767 8299482112 _export.py:607] Define JAX_DUMP_IR_TO to dump the module.

If you set the environment variable JAX_DUMP_IR_TO to a directory, the exported (and the JIT compiled) HLO modules will be saved there.

JAX_DUMP_IR_TO=/tmp/export.dumps pytest tests/export_test.py -k test_basic --log-level=3 --log-file=/tmp/mylog.txt
INFO     absl:_export.py:606 Exported JAX function: fun_name=sin version=9 lowering_platforms=('cpu',) disabled_checks=()
INFO     absl:_export.py:607 The module was dumped to jax_ir0_jit_sin_export.mlir.

You will see both the exported modules (named ..._export.mlir) and the JIT compiled modules (named ..._compile.mlir):

$ ls -l /tmp/export.dumps/
total 32
-rw-rw-r--@ 1 necula  wheel  2316 Jun 19 11:04 jax_ir0_jit_sin_export.mlir
-rw-rw-r--@ 1 necula  wheel  2279 Jun 19 11:04 jax_ir1_jit_sin_compile.mlir
-rw-rw-r--@ 1 necula  wheel  3377 Jun 19 11:04 jax_ir2_jit_call_exported_compile.mlir
-rw-rw-r--@ 1 necula  wheel  2333 Jun 19 11:04 jax_ir3_jit_my_fun_export.mlir

Set JAX_DEBUG_LOG_MODULES=jax._src.export to enable extra debugging logging.

Ensuring forward and backward compatibility#

This section discusses the process JAX developers should use to ensure the compatibility guarantees.

One complication is that external users install JAX and jaxlib as separate packages, and users often end up using an older jaxlib than JAX. We observe that the custom calls live in jaxlib, and only jaxlib is relevant for a consumer of an exported artifact. To simplify the process, we set the expectation for external users that the compatibility window is defined in terms of jaxlib releases, and that it is their responsibility to ensure that they export with a new jaxlib, even if JAX would function with an older version.

Thus, we care only about jaxlib releases. We can start a backward-compatibility deprecation clock when we make a jaxlib release, even if we don’t force it to be the minimum allowed version.

Let’s say that we need to add, delete, or change the semantics of a custom call target T used by the JAX lowering rules. Here is a possible chronology (for changing custom call targets that live in jaxlib):

  1. Day “D - 1”, before the change. Say that the active internal JAX version is 0.4.31 (the version of the next JAX and jaxlib releases). The JAX lowering rules use a custom call T.

  2. Day “D”, we add the new custom call target T_NEW. We should create a new custom call target, and clean up the old target roughly after 6 months, rather than updating T in place:

    • See the example PR #20997 implementing the steps below.

    • We add the custom call target T_NEW.

    • We change the JAX lowering rules that were previously using T to use T_NEW, conditionally as follows:

    from jax._src import config
    from jax._src.lib import version as jaxlib_version

    def my_lowering_rule(ctx: LoweringRuleContext, ...):
      if ctx.is_forward_compat() or jaxlib_version < (0, 4, 31):
        # this is the old lowering, using target T, while we
        # are in forward compatibility mode for T, or we
        # are in OSS and are using an old jaxlib.
        return hlo.custom_call("T", ...)
      else:
        # This is the new lowering, using target T_NEW, for
        # when we use a jaxlib with version `>= (0, 4, 31)`
        # (or when this is internal usage), and also we are
        # in JIT mode.
        return hlo.custom_call("T_NEW", ...)
    • Note that the forward compatibility mode is always false in JIT mode or if the user passes --jax_export_ignore_forward_compatibility=true.

    • Note that at this point the exports will still not use T_NEW.

  3. This can be done at any time after the previous step, and before the next step: add a backward compatibility test for T_NEW, and add T_NEW to the list of _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE in _export.py.

    • Instructions for adding backwards compatibility tests are at the top of export_back_compat_test_util.py.

    • An example is in PR #29488.

    • Note that if you do this before the next step, the exporting will still not use the T_NEW lowering, and you have to add with config.export_ignore_forward_compatibility(True): around the call to self.run_one_test. This can be removed when you actually get to step 4.

    • You may also need to enable the test only for new versions of jaxlib.

  4. Day “D + 21” (end of forward compatibility window; can be even later than 21 days): we remove the forward_compat_mode in the lowering code, so now exporting will start using the new custom call target T_NEW as long as we are using a new jaxlib.

  5. Day “RELEASE > D” (the first JAX release date after D, when we release version 0.4.31): we start the clock for the 6 months backwards compatibility. Note that this is relevant only if T is among the custom call targets for which we already guarantee stability, i.e., those listed in _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE.

    • If RELEASE is in the forward compatibility window [D, D + 21] and if we make RELEASE the minimum allowed jaxlib version, then we can remove the jaxlib_version < (0, 4, 31) conditional in the JIT branch.

  6. Day “RELEASE + 180” (end of backward compatibility window; can be even later than 180 days): by now, we must have bumped the minimum jaxlib so that the lowering conditional jaxlib_version < (0, 4, 31) was already removed and JAX lowering cannot generate custom calls to T.

    • We remove the C++ implementation of the old custom call target T.

    • We also remove the backwards compatibility test for T.

Migration guide from jax.experimental.export#

On June 18, 2024 (JAX version 0.4.30) we deprecated the jax.experimental.export APIs in favor of the jax.export APIs. There have been some minor changes:

  • jax.experimental.export.export:

    • The old function used to allow any Python callable, or the result of jax.jit. Now only the latter is accepted. You have to manually apply jax.jit to the function before calling export.

    • The old lowering_platforms kwarg is now named platforms.

  • jax.experimental.export.default_lowering_platform() is now at jax.export.default_export_platform().

  • jax.experimental.export.call is now a method of the jax.export.Exported object. Instead of export.call(exp) you should use exp.call.

  • jax.experimental.export.serialize is now a method of the jax.export.Exported object. Instead of export.serialize(exp) you should use exp.serialize().

  • The configuration flag --jax-serialization-version is deprecated. Use --jax-export-calling-convention-version.

  • The value jax.experimental.export.minimum_supported_serialization_version is now at jax.export.minimum_supported_calling_convention_version.

  • The following fields of jax.export.Exported have been renamed:

    • uses_shape_polymorphism is now uses_global_constants

    • mlir_module_serialization_version is now calling_convention_version

    • lowering_platforms is now platforms
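Putting the API changes together, here is a minimal sketch of migrated code (the function f and its argument are illustrative):

```python
import jax
import jax.numpy as jnp
from jax import export  # was: from jax.experimental import export

def f(x):
    return jnp.sin(x)

# Old: export.export(f)(...) accepted any Python callable.
# New: apply jax.jit yourself before exporting.
exp = export.export(jax.jit(f))(1.0)

# Old: export.call(exp) and export.serialize(exp) were module functions.
# New: they are methods on the Exported object.
blob = exp.serialize()
restored = export.deserialize(blob)
print(float(restored.call(0.0)))  # 0.0
```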

