# Exporting and serializing staged-out computations
The ahead-of-time lowering and compilation APIs produce objects that can be used for debugging or for compilation and execution in the same process. Sometimes you want to serialize a lowered JAX function for compilation and execution in a separate process, perhaps at a later time. This would allow you to:
- compile and execute the function in another process or machine without requiring access to the JAX program, and without having to repeat the staging-out and lowering, e.g., in an inference system.
- trace and lower a function on a machine that does not have access to the accelerator for which you want to later compile and execute the function.
- archive a snapshot of a JAX function, e.g., to be able to reproduce your results later. Note: check out the compatibility guarantees for this use case.
For more details see the `jax.export` API reference.
Here is an example:
```python
>>> import re
>>> import numpy as np
>>> import jax
>>> from jax import export

>>> def f(x): return 2 * x * x

>>> exported: export.Exported = export.export(jax.jit(f))(
...    jax.ShapeDtypeStruct((), np.float32))

>>> # You can inspect the Exported object
>>> exported.fun_name
'f'
>>> exported.in_avals
(ShapedArray(float32[]),)
>>> print(re.search(r".*@main.*", exported.mlir_module()).group(0))
  func.func public @main(%arg0: tensor<f32> loc("x")) -> (tensor<f32> {jax.result_info = "result"}) {

>>> # And you can serialize the Exported to a bytearray.
>>> serialized: bytearray = exported.serialize()

>>> # The serialized function can later be rehydrated and called from
>>> # another JAX computation, possibly in another process.
>>> rehydrated_exp: export.Exported = export.deserialize(serialized)
>>> rehydrated_exp.in_avals
(ShapedArray(float32[]),)
>>> def callee(y):
...  return 3. * rehydrated_exp.call(y * 4.)

>>> callee(1.)
Array(96., dtype=float32)
```
Serialization is broken down into two stages:
1. exporting to produce a `jax.export.Exported` object that contains the StableHLO for the lowered function along with the metadata necessary to call it from another JAX function. We have plans to add code to generate `Exported` objects from TensorFlow, and to use `Exported` objects from TensorFlow and PyTorch.
2. the actual serialization to a byte array using the flatbuffers format. See Interoperation with TensorFlow for an alternative serialization to a TensorFlow graph that can be used for interoperation with TensorFlow.
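For instance, the serialized bytes can be written to disk by the exporting process and rehydrated later by a consumer. The following is a minimal sketch of that round trip (the function `f` and the file name are arbitrary choices for illustration):

```python
import pathlib
import tempfile

import numpy as np
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return 2.0 * x * x

# Producer: export, serialize, and persist the bytes.
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((), np.float32))
path = pathlib.Path(tempfile.mkdtemp()) / "f.bin"
path.write_bytes(exported.serialize())

# Consumer (possibly another process, at a later time): read and rehydrate.
rehydrated = export.deserialize(bytearray(path.read_bytes()))
result = rehydrated.call(jnp.float32(3.0))  # 2 * 3**2
```

The consumer needs only the serialized bytes and a compatible jaxlib; it does not need the Python source of `f`.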
## Support for reverse-mode AD
Serialization can optionally support higher-order reverse-mode AD. This is done by serializing the `jax.vjp()` of the primal function along with the primal function, up to a user-specified order (default is 0, meaning that the rehydrated function cannot be differentiated):
```python
>>> import jax
>>> from jax import export
>>> from typing import Callable

>>> def f(x): return 7 * x * x * x

>>> # Serialize 3 levels of VJP along with the primal function
>>> blob: bytearray = export.export(jax.jit(f))(1.).serialize(vjp_order=3)
>>> rehydrated_f: Callable = export.deserialize(blob).call

>>> rehydrated_f(0.1)  # 7 * 0.1^3
Array(0.007, dtype=float32)
>>> jax.grad(rehydrated_f)(0.1)  # 7*3 * 0.1^2
Array(0.21000001, dtype=float32)
>>> jax.grad(jax.grad(rehydrated_f))(0.1)  # 7*3*2 * 0.1
Array(4.2, dtype=float32)
>>> jax.grad(jax.grad(jax.grad(rehydrated_f)))(0.1)  # 7*3*2
Array(42., dtype=float32)
>>> jax.grad(jax.grad(jax.grad(jax.grad(rehydrated_f))))(0.1)
Traceback (most recent call last):
ValueError: No VJP is available
```
Note that the VJP function is computed lazily while serializing, when the JAX program is still available. This means that it respects all features of JAX VJP, e.g., `jax.custom_vjp()` and `jax.remat()`.
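To illustrate, a `jax.custom_vjp` rule is baked into the serialized VJP, so the rehydrated gradient uses it. The sketch below uses a made-up `clip_cotangent` function whose backward rule clips the incoming cotangent:

```python
import jax
import jax.numpy as jnp
from jax import export

@jax.custom_vjp
def clip_cotangent(x):
    return x

def fwd(x):
    return x, None

def bwd(_, g):
    # Custom backward rule: clip the incoming cotangent to [-1, 1].
    return (jnp.clip(g, -1.0, 1.0),)

clip_cotangent.defvjp(fwd, bwd)

def f(x):
    return 10.0 * clip_cotangent(x)

# Serialize one level of VJP; the custom rule is captured at this point.
blob = export.export(jax.jit(f))(1.0).serialize(vjp_order=1)
rehydrated_f = export.deserialize(blob).call

primal = rehydrated_f(2.0)          # 10 * 2 = 20
grad = jax.grad(rehydrated_f)(2.0)  # cotangent 10 is clipped to 1
```

Without the custom rule, the gradient would be 10; the rehydrated gradient reflects the clipping because the VJP was traced while the custom rule was still available.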
Note that the rehydrated function does not support any other transformations, e.g., forward-mode AD (jvp), or `jax.vmap()`.
## Compatibility guarantees
You should not use the raw StableHLO that is obtained from just lowering (`jax.jit(f).lower(1.).compiler_ir()`) for archival or for compilation in another process, for several reasons.
First, the compilation may use a different version of the compiler, supporting a different version of StableHLO. The `jax.export` module takes care of this by using the portable-artifact feature of StableHLO to deal with the possible evolution of the StableHLO opset.
### Compatibility guarantees for custom calls
Second, the raw StableHLO may contain custom calls referencing C++ functions. JAX uses custom calls for lowering of a small number of primitives, e.g., linear algebra primitives, sharding annotations, or Pallas kernels. These do not fall under the compatibility guarantees for StableHLO. The C++ implementations of these functions change rarely, but they can change.
`jax.export` makes the following export compatibility guarantees: a JAX exported artifact can be compiled and executed by a compiler and JAX runtime system that are:

- up to 6 months newer than the version of JAX used for exporting (we say that JAX export offers *6 months backward compatibility*). This is useful if we want to archive the exported artifact to be compiled and executed later.
- up to 3 weeks older than the version of JAX used for exporting (we say that JAX export offers *3 weeks forward compatibility*). This is useful if we want to compile and run an exported artifact with a consumer that was built and deployed before the export, e.g., an inference system that is already deployed when the exporting is done.
(The particular compatibility window lengths are the same that JAX promised for jax2tf, and are based on TensorFlow Compatibility. The terminology “backward compatibility” is from the perspective of the consumer, e.g., the inference system.)
What matters is when the exporting and consuming components were built, not the time when the exporting and the compilation happen. For external JAX users, it is possible to run JAX and jaxlib at different versions; what matters is when the jaxlib release was built.
To reduce chances of incompatibility, internal JAX users should:

- rebuild and redeploy consumer systems as frequently as possible.

and external users should:

- run the exporting and consumer systems with the same version of jaxlib, whenever possible, and
- export for archival with the latest released version of jaxlib.
The compatibility guarantees do not apply if you bypass the `jax.export` APIs to obtain the StableHLO code.
In order to ensure forward compatibility, when we change the JAX lowering rules to use a new custom call target, JAX will refrain for 3 weeks from using the new target. To use the latest lowering rules, you can pass the `--jax_export_ignore_forward_compatibility=1` configuration flag or the `JAX_EXPORT_IGNORE_FORWARD_COMPATIBILITY=1` environment variable.
Only a subset of custom calls are guaranteed stable and have compatibility guarantees (see list). We continuously add more custom call targets to the allowed list along with backwards compatibility tests. If you try to serialize code that invokes other custom call targets you will get an error during exporting.
If you want to disable this safety check for a specific custom call, e.g., with target `my_target`, you can add `export.DisabledSafetyCheck.custom_call("my_target")` to the `disabled_checks` parameter of the `export` method, as in the following example:
```python
>>> import jax
>>> from jax import export
>>> from jax import lax
>>> from jax._src import core
>>> from jax._src.interpreters import mlir

>>> # Define a new primitive backed by a custom call
>>> new_prim = core.Primitive("new_prim")
>>> _ = new_prim.def_abstract_eval(lambda x: x)
>>> _ = mlir.register_lowering(new_prim,
...                            lambda ctx, o: mlir.custom_call("my_new_prim",
...                                                            operands=[o],
...                                                            result_types=[o.type]).results)
>>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir())
module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = "result"}) {
    %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
}

>>> # If we try to export, we get an error
>>> export.export(jax.jit(new_prim.bind))(1.)
Traceback (most recent call last):
ValueError: Cannot serialize code with custom calls whose targets have no compatibility guarantees: my_new_prim

>>> # We can avoid the error if we pass a `DisabledSafetyCheck.custom_call`
>>> exp = export.export(
...    jax.jit(new_prim.bind),
...    disabled_checks=[export.DisabledSafetyCheck.custom_call("my_new_prim")])(1.)
```
See Ensuring forward and backward compatibility for developer information regarding ensuring compatibility.
## Cross-platform and multi-platform export
JAX lowering is platform specific for a small number of JAX primitives. By default, the code is lowered and exported for the accelerator present on the exporting machine:
```python
>>> from jax import export
>>> export.default_export_platform()
'cpu'
```
There is a safety check that will raise an error when trying to compile an `Exported` object on a machine that does not have the accelerator for which the code was exported.
You can specify explicitly for what platforms the code should be exported. This allows you to specify a different accelerator than the one you have available at export time, and it even allows you to specify multi-platform export to obtain an `Exported` object that can be compiled and executed on multiple platforms.
```python
>>> import jax
>>> from jax import export
>>> from jax import lax

>>> # You can specify the export platform, e.g., `tpu`, `cpu`, `cuda`, `rocm`
>>> # even if the current machine does not have that accelerator.
>>> exp = export.export(jax.jit(lax.cos), platforms=['tpu'])(1.)

>>> # But you will get an error if you try to compile `exp`
>>> # on a machine that does not have TPUs.
>>> exp.call(1.)
Traceback (most recent call last):
ValueError: Function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'.

>>> # We can avoid the error if we pass a `DisabledSafetyCheck.platform`
>>> # parameter to `export`, e.g., because you have reasons to believe
>>> # that the code lowered will run adequately on the current
>>> # compilation platform (which is the case for `cos` in this
>>> # example):
>>> exp_unsafe = export.export(jax.jit(lax.cos),
...                            platforms=['tpu'],
...                            disabled_checks=[export.DisabledSafetyCheck.platform()])(1.)

>>> exp_unsafe.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)

>>> # and similarly with multi-platform lowering
>>> exp_multi = export.export(jax.jit(lax.cos),
...                           platforms=['tpu', 'cpu', 'cuda'])(1.)
>>> exp_multi.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)
```
For multi-platform export, the StableHLO will contain multiple lowerings, but only for those primitives that require it, so the resulting module size should be only marginally larger than the size of a module with default export. As an extreme case, when serializing a module without any primitives with platform-specific lowering, you will get the same StableHLO as for the single-platform export.
```python
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import export
>>> from jax import lax

>>> # A largish function
>>> def f(x):
...   for i in range(1000):
...     x = jnp.cos(x)
...   return x

>>> exp_single = export.export(jax.jit(f))(1.)
>>> len(exp_single.mlir_module_serialized)
9220

>>> exp_multi = export.export(jax.jit(f),
...                           platforms=["cpu", "tpu", "cuda"])(1.)
>>> len(exp_multi.mlir_module_serialized)
9282
```
## Shape polymorphic export
When used in JIT mode, JAX will trace and lower a function separately for each combination of input shapes. When exporting, it is possible in some cases to use dimension variables for some input dimensions in order to obtain an exported artifact that can be used with multiple combinations of input shapes.
See the Shape polymorphism documentation.
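As a brief sketch of what this looks like (see the linked documentation for the details and the constraint system), you can construct symbolic input shapes with `jax.export.symbolic_shape`; here `b` is a dimension variable name chosen for illustration:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

def f(x):  # x: f32[b, 4]
    return jnp.sum(x, axis=1)

# Leave the leading (batch) dimension symbolic.
exp = export.export(jax.jit(f))(
    jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), np.float32))

# The same exported artifact accepts multiple batch sizes.
res3 = exp.call(np.ones((3, 4), dtype=np.float32))
res16 = exp.call(np.ones((16, 4), dtype=np.float32))
```

Without the symbolic dimension, each distinct batch size would require a separate export.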
## Device-polymorphic export
An exported artifact may contain sharding annotations for inputs, outputs, and for some intermediates, but these annotations do not refer directly to the actual physical devices that existed at export time. Instead, the sharding annotations refer to logical devices. This means that you can compile and run the exported artifact on physical devices different from the ones used for exporting.
The cleanest way to achieve a device-polymorphic export is to use shardings constructed with a `jax.sharding.AbstractMesh`, which contains only the mesh shape and axis names. But you can achieve the same results if you use shardings constructed for a mesh with concrete devices, since the actual devices in the mesh are ignored for tracing and lowering:
```python
>>> import numpy as np
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import export
>>> from jax.sharding import AbstractMesh, Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> # Use an AbstractMesh for exporting
>>> export_mesh = AbstractMesh((4,), ("a",))
>>> def f(x):
...   return x.T

>>> exp = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((32,), dtype=np.int32,
...                          sharding=NamedSharding(export_mesh, P("a"))))

>>> # `exp` knows for how many devices it was exported.
>>> exp.nr_devices
4

>>> # and it knows the shardings for the inputs. These will be applied
>>> # when the exported is called.
>>> exp.in_shardings_hlo
({devices=[4]<=[4]},)

>>> # You can also use a concrete set of devices for exporting
>>> concrete_devices = jax.local_devices()[:4]
>>> concrete_mesh = Mesh(concrete_devices, ("a",))
>>> exp2 = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((32,), dtype=np.int32,
...                          sharding=NamedSharding(concrete_mesh, P("a"))))

>>> # You can expect the same results
>>> assert exp.in_shardings_hlo == exp2.in_shardings_hlo

>>> # When you call an Exported, you must use a concrete set of devices
>>> arg = jnp.arange(8 * 4)
>>> res1 = exp.call(jax.device_put(arg,
...                                NamedSharding(concrete_mesh, P("a"))))

>>> # Check out the first 2 shards of the result
>>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]]
['device=TFRT_CPU_0 index=(slice(0, 8, None),)', 'device=TFRT_CPU_1 index=(slice(8, 16, None),)']

>>> # We can call `exp` with some other 4 devices and another
>>> # mesh with a different shape, as long as the number of devices is
>>> # the same.
>>> other_mesh = Mesh(np.array(jax.local_devices()[2:6]).reshape((2, 2)), ("b", "c"))
>>> res2 = exp.call(jax.device_put(arg,
...                                NamedSharding(other_mesh, P("b"))))

>>> # Check out the first 2 shards of the result. Notice that the output is
>>> # sharded similarly; this means that the input was resharded according to the
>>> # exp.in_shardings.
>>> [f"device={s.device} index={s.index}" for s in res2.addressable_shards[:2]]
['device=TFRT_CPU_2 index=(slice(0, 8, None),)', 'device=TFRT_CPU_3 index=(slice(8, 16, None),)']
```
It is an error to try to invoke an exported artifact with a different number of devices than it was exported for:
```python
>>> import numpy as np
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
...   return x.T

>>> exp = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
...                          sharding=NamedSharding(export_mesh, P("a"))))

>>> arg = jnp.arange(4 * len(export_devices))
>>> exp.call(arg)
Traceback (most recent call last):
ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device.
```
There are helper functions to shard the inputs for calling an exported artifact using a new mesh constructed at the call site:
```python
>>> import numpy as np
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
...   return x.T

>>> exp = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
...                          sharding=NamedSharding(export_mesh, P("a"))))

>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("a",))

>>> # Shard the arg according to what `exp` expects.
>>> arg = jnp.arange(4 * len(export_devices))
>>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
>>> res = exp.call(sharded_arg)
```
As a special facility, if a function was exported for 1 device and if it contains no sharding annotations, then it can be invoked on an argument of the same shape but sharded on multiple devices, and the compiler will shard the function appropriately:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> def f(x):
...   return jnp.cos(x)

>>> arg = jnp.arange(4)
>>> exp = export.export(jax.jit(f))(arg)
>>> exp.in_avals
(ShapedArray(int32[4]),)
>>> exp.nr_devices
1

>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(jax.local_devices()[:4], ("b",))

>>> # Shard the arg according to what `exp` expects.
>>> sharded_arg = jax.device_put(arg,
...                              NamedSharding(calling_mesh, P("b")))
>>> res = exp.call(sharded_arg)
```
## Calling convention versions
The JAX export support has evolved over time, e.g., to support effects. In order to support compatibility (see compatibility guarantees) we maintain a calling convention version for each `Exported`. As of July 2025, functions are exported with calling convention version 10 (the latest; see all calling convention versions):
```python
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import export

>>> exp: export.Exported = export.export(jax.jit(jnp.cos))(1.)
>>> exp.calling_convention_version
10
```
At any given time, the export APIs may support a range of calling convention versions. You can control which calling convention version to use with the `--jax_export_calling_convention_version` flag or the `JAX_EXPORT_CALLING_CONVENTION_VERSION` environment variable:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import export

>>> (export.minimum_supported_calling_convention_version,
...  export.maximum_supported_calling_convention_version)
(9, 10)

>>> from jax._src import config
>>> with config.jax_export_calling_convention_version(10):
...   exp = export.export(jax.jit(jnp.cos))(1.)
...   exp.calling_convention_version
10
```
We reserve the right to remove support for generating or consuming calling convention versions older than 6 months.
## Module calling convention
The `Exported.mlir_module` has a `main` function that takes an optional first platform index argument if the module supports multiple platforms (`len(platforms) > 1`), followed by the token arguments corresponding to the ordered effects, followed by the kept array arguments (corresponding to `module_kept_var_idx` and `in_avals`). The platform index is an i32 or i64 scalar encoding the index of the current compilation platform into the `platforms` sequence.
Inner functions use a different calling convention: an optional platform index argument, optional dimension variable arguments (scalar tensors of type i32 or i64), followed by optional token arguments (in presence of ordered effects), followed by the regular array arguments. The dimension arguments correspond to the dimension variables appearing in the `args_avals`, in sorted order of their names.
Consider the lowering of a function with one array argument of type `f32[w, 2*h]`, where `w` and `h` are two dimension variables. Assume that we use multi-platform lowering, and we have one ordered effect. The `main` function will be as follows:
```
func public main(
      platform_index: i32 {jax.global_constant="_platform_index"},
      token_in: token,
      arg: f32[?, ?]) {
   arg_w = hlo.get_dimension_size(arg, 0)
   dim1 = hlo.get_dimension_size(arg, 1)
   arg_h = hlo.floordiv(dim1, 2)
   call _check_shape_assertions(arg)  # See below
   token = new_token()
   token_out, res = call _wrapped_jax_export_main(platform_index,
                                                  arg_h,
                                                  arg_w,
                                                  token_in,
                                                  arg)
   return token_out, res
}
```

The actual computation is in `_wrapped_jax_export_main`, which also takes the values of the `h` and `w` dimension variables.
The signature of `_wrapped_jax_export_main` is:
```
func private _wrapped_jax_export_main(
    platform_index: i32 {jax.global_constant="_platform_index"},
    arg_h: i32 {jax.global_constant="h"},
    arg_w: i32 {jax.global_constant="w"},
    arg_token: stablehlo.token {jax.token=True},
    arg: f32[?, ?]) -> (stablehlo.token, ...)
```

Prior to calling convention version 9 the calling convention for effects was different: the `main` function does not take or return a token. Instead the function creates dummy tokens of type `i1[0]` and passes them to the `_wrapped_jax_export_main`. The `_wrapped_jax_export_main` takes dummy tokens of type `i1[0]` and will create internally real tokens to pass to the inner functions. The inner functions use real tokens (both before and after calling convention version 9).
Also starting with calling convention version 9, function arguments that contain the platform index or the dimension variable values have a `jax.global_constant` string attribute whose value is the name of the global constant, either `_platform_index` or a dimension variable name. The global constant name may be empty if it is not known. Some global constant computations use inner functions, e.g., for `floor_divide`. The arguments of such functions have a `jax.global_constant` attribute for all arguments, meaning that the result of the function is also a global constant.
Note that `main` contains a call to `_check_shape_assertions`. JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h` have values >= 1. We must check these constraints when we invoke the module. We use a special custom call `@shape_assertion` that takes a boolean first operand, a string `error_message` attribute that may contain format specifiers `{0}`, `{1}`, …, and a variadic number of integer scalar operands corresponding to the format specifiers.
```
func private _check_shape_assertions(arg: f32[?, ?]) {
  # Check that w is >= 1
  arg_w = hlo.get_dimension_size(arg, 0)
  custom_call @shape_assertion(arg_w >= 1, arg_w,
      error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
  # Check that dim1 is even
  dim1 = hlo.get_dimension_size(arg, 1)
  custom_call @shape_assertion(dim1 % 2 == 0, dim1 % 2,
      error_message="Division had remainder {0} when computing the value of 'h'")
  # Check that h >= 1
  arg_h = hlo.floordiv(dim1, 2)
  custom_call @shape_assertion(arg_h >= 1, arg_h,
      error_message="Dimension variable 'h' must have integer value >= 1. Found {0}")
}
```

## Calling convention versions
We list here a history of the calling convention version numbers:
- Version 1 used MHLO & CHLO to serialize the code, not supported anymore.
- Version 2 supports StableHLO & CHLO. Used from October 2022. Not supported anymore.
- Version 3 supports platform checking and multiple platforms. Used from February 2023. Not supported anymore.
- Version 4 supports StableHLO with compatibility guarantees. This is the earliest version at the time of the JAX native serialization launch. Used in JAX from March 15, 2023 (cl/516885716). Starting with March 28th, 2023 we stopped using `dim_args_spec` (cl/520033493). The support for this version was dropped on October 17th, 2023 (cl/573858283).
- Version 5 adds support for `call_tf_graph`. This is currently used for some specialized use cases. Used in JAX from May 3rd, 2023 (cl/529106145).
- Version 6 adds support for the `disabled_checks` attribute. This version mandates a non-empty `platforms` attribute. Supported by XlaCallModule since June 7th, 2023 and available in JAX since June 13th, 2023 (JAX 0.4.13).
- Version 7 adds support for `stablehlo.shape_assertion` operations and for `shape_assertions` specified in `disabled_checks`. See Errors in presence of shape polymorphism. Supported by XlaCallModule since July 12th, 2023 (cl/547482522), available in JAX serialization since July 20th, 2023 (JAX 0.4.14), and the default since August 12th, 2023 (JAX 0.4.15).
- Version 8 adds support for the `jax.uses_shape_polymorphism` module attribute and enables the shape refinement pass only when the attribute is present. Supported by XlaCallModule since July 21st, 2023 (cl/549973693), available in JAX since July 26th, 2023 (JAX 0.4.14), and the default since October 21st, 2023 (JAX 0.4.20).
- Version 9 adds support for effects. See the docstring for `export.Exported` for the precise calling convention. In this calling convention version we also tag the platform index and the dimension variables arguments with `jax.global_constant` attributes. Supported by XlaCallModule since October 27th, 2023, available in JAX since October 20th, 2023 (JAX 0.4.20), and the default since February 1st, 2024 (JAX 0.4.24). This was the only supported version as of 27th of March, 2024.
- Version 10 propagates the `jax.config.use_shardy_partitioner` value to XlaCallModule. Supported by XlaCallModule since May 20th, 2025, and the default in JAX since July 14th, 2025 (JAX 0.7.0).
## Developer documentation
### Debugging
You can log the exported modules, with somewhat different flags in OSS versus in Google. In OSS you can do the following:
```shell
# Log from python
python tests/export_test.py JaxExportTest.test_basic -v=3

# Or, log from pytest to /tmp/mylog.txt
pytest tests/export_test.py -k test_basic --log-level=3 --log-file=/tmp/mylog.txt
```
You will see a log line of the form:
```
I0619 10:54:18.978733 8299482112 _export.py:606] Exported JAX function: fun_name=sin version=9 lowering_platforms=('cpu',) disabled_checks=()
I0619 10:54:18.978767 8299482112 _export.py:607] Define JAX_DUMP_IR_TO to dump the module.
```
If you set the environment variable `JAX_DUMP_IR_TO` to a directory, the exported (and the JIT compiled) HLO modules will be saved there.
```shell
JAX_DUMP_IR_TO=/tmp/export.dumps pytest tests/export_test.py -k test_basic --log-level=3 --log-file=/tmp/mylog.txt
INFO absl:_export.py:606 Exported JAX function: fun_name=sin version=9 lowering_platforms=('cpu',) disabled_checks=()
INFO absl:_export.py:607 The module was dumped to jax_ir0_jit_sin_export.mlir.
```
You will see both the exported modules (named `..._export.mlir`) and the JIT compiled modules (named `..._compile.mlir`):
```shell
$ ls -l /tmp/export.dumps/
total 32
-rw-rw-r--@ 1 necula  wheel  2316 Jun 19 11:04 jax_ir0_jit_sin_export.mlir
-rw-rw-r--@ 1 necula  wheel  2279 Jun 19 11:04 jax_ir1_jit_sin_compile.mlir
-rw-rw-r--@ 1 necula  wheel  3377 Jun 19 11:04 jax_ir2_jit_call_exported_compile.mlir
-rw-rw-r--@ 1 necula  wheel  2333 Jun 19 11:04 jax_ir3_jit_my_fun_export.mlir
```
Set `JAX_DEBUG_LOG_MODULES=jax._src.export` to enable extra debugging logging.
### Ensuring forward and backward compatibility
This section discusses the process JAX developers should use to ensure the compatibility guarantees.
One complication is that external users install JAX and jaxlib in separate packages, and users often end up using an older jaxlib than JAX. We observe that the custom calls live in the jaxlib, and only the jaxlib is relevant for a consumer of an exported artifact. To simplify the process, we are setting the expectation for external users that the compatibility window is defined in terms of jaxlib releases, and it is their responsibility to ensure that they export with a new jaxlib even if JAX would function with an older version.
Thus, we care only about jaxlib releases. We can start a backward-compatibility deprecation clock when we make a jaxlib release, even if we don’t force it to be the minimum allowed version.
Let’s say that we need to add, delete, or change the semantics of a custom call target `T` used by the JAX lowering rules. Here is a possible chronology (for changing custom call targets that live in jaxlib):
1. Day “D - 1”, before the change. Say that the active internal JAX version is 0.4.31 (the version of the next JAX and jaxlib releases). The JAX lowering rules use a custom call `T`.
2. Day “D”, we add the new custom call target `T_NEW`. We should create a new custom call target, and clean up the old target roughly after 6 months, rather than updating `T` in place (see the example PR #20997 implementing the steps below):
   - We add the custom call target `T_NEW`.
   - We change the JAX lowering rules that were previously using `T` to use `T_NEW`, conditionally as follows:

     ```python
     from jax._src import config
     from jax._src.lib import version as jaxlib_version

     def my_lowering_rule(ctx: LoweringRuleContext, ...):
       if ctx.is_forward_compat() or jaxlib_version < (0, 4, 31):
         # This is the old lowering, using target T, while we
         # are in forward compatibility mode for T, or we
         # are in OSS and are using an old jaxlib.
         return hlo.custom_call("T", ...)
       else:
         # This is the new lowering, using target T_NEW, for
         # when we use a jaxlib with version `>= (0, 4, 31)`
         # (or when this is internal usage), and also we are
         # in JIT mode.
         return hlo.custom_call("T_NEW", ...)
     ```

     Note that the forward compatibility mode is always false in JIT mode or if the user passes `--jax_export_ignore_forward_compatibility=true`. Note that at this point the exports will still not use `T_NEW`.
3. This can be done at any time after the previous step, and before the next step: add a backward compatibility test for `T_NEW`, and add `T_NEW` to the list of `_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE` in `_export.py`. Instructions for adding backwards compatibility tests are at the top of `export_back_compat_test_util.py`. An example is in PR #29488.
   Note that if you do this before the next step, the exporting will still not use the `T_NEW` lowering, and you have to add `with config.export_ignore_forward_compatibility(True):` around the call to `self.run_one_test`. This can be removed when you actually get to step 4. You may also need to enable the test only for new versions of jaxlib.
4. Day “D + 21” (end of forward compatibility window; can be even later than 21 days): we remove the `forward_compat_mode` in the lowering code, so now exporting will start using the new custom call target `T_NEW`, as long as we are using a new jaxlib.
5. Day “RELEASE > D” (the first JAX release date after D, when we release version 0.4.31): we start the clock for the 6 months backwards compatibility. Note that this is relevant only if `T` is among the custom call targets for which we already guarantee stability, i.e., those listed in `_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE`.
   If RELEASE is in the forward compatibility window `[D, D + 21]` and if we make RELEASE the minimum allowed jaxlib version, then we can remove the `jaxlib_version < (0, 4, 31)` conditional in the JIT branch.
6. Day “RELEASE + 180” (end of backward compatibility window, can be even later than 180 days): by now, we must have bumped the minimum jaxlib so that the lowering conditional `jaxlib_version < (0, 4, 31)` was already removed and JAX lowering cannot generate custom calls to `T`.
   - We remove the C++ implementation of the old custom call target `T`.
   - We also remove the backwards compatibility test for `T`.
## Migration guide from jax.experimental.export
On June 18, 2024 (JAX version 0.4.30) we deprecated the `jax.experimental.export` APIs in favor of the `jax.export` APIs. There have been some minor changes:
- `jax.experimental.export.export`:
  - The old function used to allow any Python callable, or the result of `jax.jit`. Now only the latter is accepted. You have to manually apply `jax.jit` to the function to export before calling `export`.
  - The old `lowering_parameters` kwarg is now named `platforms`.
- `jax.experimental.export.default_lowering_platform()` is now at `jax.export.default_export_platform()`.
- `jax.experimental.export.call` is now a method of the `jax.export.Exported` object. Instead of `export.call(exp)` you should use `exp.call`.
- `jax.experimental.export.serialize` is now a method of the `jax.export.Exported` object. Instead of `export.serialize(exp)` you should use `exp.serialize()`.
- The configuration flag `--jax-serialization-version` is deprecated. Use `--jax-export-calling-convention-version`.
- The value `jax.experimental.export.minimum_supported_serialization_version` is now at `jax.export.minimum_supported_calling_convention_version`.
- The following fields of `jax.export.Exported` have been renamed:
  - `uses_shape_polymorphism` is now `uses_global_constants`
  - `mlir_module_serialization_version` is now `calling_convention_version`
  - `lowering_platforms` is now `platforms`.
