jax.export module
jax.export is a library for exporting and serializing JAX functions for persistent archival.
See the Exporting and serialization documentation.
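Below is a minimal round-trip sketch, assuming a CPU backend and the Exported.serialize() method (which this section does not describe). It exports a jax.jit-wrapped function for a given argument shape, serializes it to bytes, and then deserializes and calls it. The function f is a hypothetical example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

def f(x):
    # Hypothetical function to export; any jax.jit-able function would do.
    return 2.0 * jnp.sin(x)

# Export for float32[4] inputs. A jax.ShapeDtypeStruct describes the argument
# without needing real data; a concrete array would work as well.
spec = jax.ShapeDtypeStruct((4,), jnp.float32)
exp = export.export(jax.jit(f))(spec)

# Serialize to bytes for archival (assumes Exported.serialize()).
blob = exp.serialize()

# Later, possibly in another process, deserialize and call from JAX.
restored = export.deserialize(blob)
x = np.arange(4, dtype=np.float32)
result = restored.call(x)
print(result)
```

Exporting from a ShapeDtypeStruct keeps the export step independent of any particular data, and the deserialized object is called like the original function.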
Classes
- class jax.export.Exported(fun_name, in_tree, in_avals, out_tree, out_avals, _has_named_shardings, _in_named_shardings, _out_named_shardings, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)
A JAX function lowered to StableHLO.
- Parameters:
fun_name (str)
in_tree (tree_util.PyTreeDef)
in_avals (tuple[core.ShapedArray, ...])
out_tree (tree_util.PyTreeDef)
out_avals (tuple[core.ShapedArray, ...])
_has_named_shardings (bool)
_in_named_shardings (tuple[NamedSharding | None, ...])
_out_named_shardings (tuple[NamedSharding | None, ...])
in_shardings_hlo (tuple[HloSharding | None, ...])
out_shardings_hlo (tuple[HloSharding | None, ...])
nr_devices (int)
ordered_effects (tuple[effects.Effect, ...])
unordered_effects (tuple[effects.Effect, ...])
disabled_safety_checks (Sequence[DisabledSafetyCheck])
mlir_module_serialized (bytes)
calling_convention_version (int)
uses_global_constants (bool)
- in_tree
a PyTreeDef describing the tuple (args, kwargs) of the lowered JAX function. The actual lowering does not depend on the in_tree, but this can be used to invoke the exported function using the same argument structure.
- Type:
tree_util.PyTreeDef
- in_avals
the flat tuple of input abstract values. May contain dimension expressions in the shapes.
- Type:
tuple[core.ShapedArray, …]
- out_tree
a PyTreeDef describing the result of the lowered JAX function.
- Type:
tree_util.PyTreeDef
- out_avals
the flat tuple of output abstract values. May contain dimension expressions in the shapes, with dimension variables among those in in_avals. Note that when out_shardings are not specified for an output, the out_avals.sharding.spec for Auto axes may be None even if the compiled code ends up using a non-replicated sharding. See https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html#concrete-array-shardings-can-mention-auto-mesh-axis for more details.
- Type:
tuple[core.ShapedArray, …]
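To illustrate how in_tree, in_avals, out_tree and out_avals fit together, here is a sketch with a hypothetical function g that takes a keyword argument and returns a dict. Unflattening the flat avals with the tree definitions recovers the argument and result structure, with abstract values in place of arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

def g(x, *, scale):
    # Hypothetical function: one positional arg, one keyword arg, dict result.
    return {"scaled": x * scale, "total": jnp.sum(x)}

x = np.ones((3,), dtype=np.float32)
exp = export.export(jax.jit(g))(x, scale=np.float32(2.0))

# in_tree describes the (args, kwargs) tuple of the exported call.
args_avals, kwargs_avals = exp.in_tree.unflatten(exp.in_avals)
print(args_avals[0].shape, kwargs_avals["scale"].shape)  # (3,) ()

# out_tree describes the result; here a dict with two leaves.
out = exp.out_tree.unflatten(exp.out_avals)
print(out["scaled"].shape, out["total"].shape)  # (3,) ()
```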
- in_shardings_hlo
the flattened input shardings, a sequence as long as in_avals. None means unspecified sharding. Note that these do not include the mesh or the actual devices used in the mesh, and in general you should avoid using this field directly. See in_shardings_jax for a way to turn these into sharding specifications that can be used with JAX APIs.
- Type:
tuple[HloSharding | None, …]
- out_shardings_hlo
the flattened output shardings, a sequence as long as out_avals. None means unspecified sharding. Note that these do not include the mesh or the actual devices used in the mesh, and in general you should avoid using this field directly. See out_shardings_jax for a way to turn these into sharding specifications that can be used with JAX APIs.
- Type:
tuple[HloSharding | None, …]
- platforms
a tuple containing the platforms for which the function should be exported. The set of platforms in JAX is open-ended; users can add platforms. JAX built-in platforms are: ‘tpu’, ‘cpu’, ‘cuda’, ‘rocm’. See https://docs.jax.dev/en/latest/export/export.html#cross-platform-and-multi-platform-export.
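A sketch of multi-platform export, assuming the platforms keyword of export.export and a machine whose default backend is CPU. Lowering does not need the target hardware, so the TPU variant can be produced on a CPU-only host, and the resulting Exported can be called on any of the listed platforms.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

x = np.float32(1.0)

# Lower jnp.cos for both CPU and TPU; no TPU is needed to do this.
exp = export.export(jax.jit(jnp.cos), platforms=("cpu", "tpu"))(x)
print(exp.platforms)

# Calling succeeds because the running backend (assumed to be CPU) is one of
# the exported platforms; an unlisted platform would raise an error instead.
y = exp.call(x)
print(y)
```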
- ordered_effects
the ordered effects present in the serialized module. This is present from serialization version 9. See https://docs.jax.dev/en/latest/export/export.html#module-calling-convention for the calling convention in the presence of ordered effects.
- Type:
tuple[effects.Effect, …]
- unordered_effects
the unordered effects present in the serialized module. This is present from serialization version 9.
- Type:
tuple[effects.Effect, …]
- calling_convention_version
a version number for the calling convention of the exported module. See more versioning details at https://docs.jax.dev/en/latest/export/export.html#calling-convention-versions.
- Type:
int
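A quick sanity check of calling_convention_version against the range the running JAX supports. This sketch assumes the module attributes export.minimum_supported_calling_convention_version and export.maximum_supported_calling_convention_version, and that the version survives a serialization round trip via Exported.serialize().

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

exp = export.export(jax.jit(jnp.tanh))(np.zeros((2,), np.float32))

# Assumed module-level integer constants describing the supported range.
lo = export.minimum_supported_calling_convention_version
hi = export.maximum_supported_calling_convention_version
print(lo, exp.calling_convention_version, hi)

# The version is recorded in the serialized form and restored on load.
restored = export.deserialize(exp.serialize())
print(restored.calling_convention_version == exp.calling_convention_version)
```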
- module_kept_var_idx
the sorted indices of the arguments among in_avals that must be passed to the module. The other arguments have been dropped because they are not used.
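A sketch of module_kept_var_idx with a hypothetical function h whose second argument is unused. Assuming the default jax.jit behavior of dropping unused arguments (keep_unused=False), only index 0 is expected to be kept, while callers still pass the full argument list to call().

```python
import jax
import numpy as np
from jax import export

def h(x, unused):
    # Hypothetical function: `unused` never contributes to the result.
    return x * 2.0

x = np.ones((2,), np.float32)
exp = export.export(jax.jit(h))(x, x)

print(len(exp.in_avals))        # both arguments are still described
print(exp.module_kept_var_idx)  # only the used argument is kept

# call() takes the full argument list, matching in_avals.
y = exp.call(x, x)
```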
- uses_global_constants
whether the mlir_module_serialized uses shape polymorphism or multi-platform export. This may be because in_avals contains dimension variables, or due to inner calls of Exported modules that have dimension variables or platform index arguments. Such modules need shape refinement before XLA compilation.
- Type:
bool
- disabled_safety_checks
a list of descriptors of safety checks that have been disabled at export time. See the docstring for DisabledSafetyCheck.
- Type:
Sequence[DisabledSafetyCheck]
- _get_vjp
an optional function that takes the current exported function and returns the exported VJP function. The VJP function takes a flat list of arguments, starting with the primal arguments and followed by a cotangent argument for each primal output. It returns a tuple with the cotangents corresponding to the flattened primal inputs.
DO NOT RELY directly on fields whose name starts with ‘_’. They will change.
See a description of the calling convention for the mlir_module() method at https://docs.jax.dev/en/latest/export/export.html#module-calling-convention.
- call(*args, **kwargs)
Call an exported function from a JAX program.
- Parameters:
args – the positional arguments to pass to the exported function. This should be a pytree of arrays with the same pytree structure as the arguments for which the function was exported.
kwargs – the keyword arguments to pass to the exported function.
- Returns:
a pytree of result arrays, with the same structure as the results of the exported function.
The invocation supports reverse-mode AD and all the features supported by exporting: shape polymorphism, multi-platform, and device polymorphism. See the examples in the [JAX export documentation](https://docs.jax.dev/en/latest/export/export.html).
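A sketch of reverse-mode AD through call(). Differentiating an in-memory Exported works directly; for a deserialized one, the VJP must have been serialized too, which this sketch assumes is requested with Exported.serialize(vjp_order=1).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

x = np.float32(0.5)
exp = export.export(jax.jit(jnp.sin))(x)

# Gradient through the in-memory Exported.
g = jax.grad(exp.call)(x)

# Serialize together with the first-order VJP (assumed vjp_order parameter),
# so the deserialized function stays differentiable.
restored = export.deserialize(exp.serialize(vjp_order=1))
g_restored = jax.grad(restored.call)(x)
print(g, g_restored)  # both approximately cos(0.5)
```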
- in_shardings_jax(mesh)
Creates Shardings corresponding to self.in_shardings_hlo and self._in_named_shardings.
The Exported object stores in_shardings_hlo as HloShardings, and after 12/5/2025 also _in_named_shardings as NamedShardings with abstract meshes. This method constructs Shardings that can be used in JAX APIs such as jax.jit() or jax.device_put(). The mesh argument may be a concrete mesh.
Example usage:
>>> import jax
>>> import numpy as np
>>> from jax import export, sharding
>>> # Prepare the exported object:
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
...                             in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
...                     )(np.arange(jax.device_count()))
>>> exp.in_shardings_hlo
({devices=[8]<=[8]},)
>>> # Create a mesh for running the exported object
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("a",))
>>> # Put the args and kwargs on the appropriate devices
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
...                          exp.in_shardings_jax(run_mesh)[0])
>>> res = exp.call(run_arg)
>>> res.addressable_shards
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
 Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
 Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
 Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
 Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
 Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
 Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
 Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
- Parameters:
mesh (mesh_lib.Mesh)
- Return type:
Sequence[sharding.Sharding | None]
- out_shardings_jax(mesh)
Creates Shardings for out_shardings_hlo and _out_named_shardings.
See the documentation for in_shardings_jax.
- Parameters:
mesh (mesh_lib.Mesh)
- Return type:
Sequence[sharding.Sharding | None]
- class jax.export.DisabledSafetyCheck(_impl)
A safety check that should be skipped on (de)serialization.
Most of these checks are performed on serialization, but some are deferred to deserialization. The list of disabled checks is attached to the serialization, e.g., as a sequence of string attributes to jax.export.Exported or of tf.XlaCallModuleOp.
When using jax2tf, you can disable more deserialization safety checks by passing TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform.
- Parameters:
_impl (str)
- classmethod custom_call(target_name)
Allows the serialization of a call target not known to be stable.
Has effect only on serialization.
- Parameters:
target_name (str) – the name of the custom call target to allow.
- Return type:
DisabledSafetyCheck
- is_custom_call()
Returns the custom call target allowed by this directive.
- Return type:
str | None
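A sketch of building a DisabledSafetyCheck and attaching it at export time, assuming the disabled_checks keyword of export.export. The target name my_custom_target is hypothetical; in practice it would name a custom call that appears in the lowered module.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# "my_custom_target" is a hypothetical custom call target name.
check = export.DisabledSafetyCheck.custom_call("my_custom_target")
print(check.is_custom_call())  # my_custom_target

# The disabled checks are passed at export time and recorded on the result.
exp = export.export(jax.jit(jnp.abs), disabled_checks=[check])(
    np.ones((2,), np.float32))
print([c.is_custom_call() for c in exp.disabled_safety_checks])
```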
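Functions
| Name | Description |
|---|---|
| export | Exports a JAX function for persistent serialization. |
| deserialize | Deserializes an Exported. |
| minimum_supported_calling_convention_version | The minimum supported calling convention version (an int). |
| maximum_supported_calling_convention_version | The maximum supported calling convention version (an int). |
| default_export_platform | Retrieves the default export platform. |
| register_pytree_node_serialization | Registers a custom PyTree node for serialization and deserialization. |
| register_namedtuple_serialization | Registers a namedtuple for serialization and deserialization. |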
Functions related to shape polymorphism
| Name | Description |
|---|---|
| symbolic_shape | Constructs a symbolic shape from a string representation. |
| symbolic_args_specs | Constructs a pytree of jax.ShapeDtypeStruct argument specs for export. |
| is_symbolic_dim | Checks if a dimension is symbolic. |
| SymbolicScope | Identifies a scope for symbolic expressions. |
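A sketch tying these helpers together: symbolic_shape parses a shape string, is_symbolic_dim tells symbolic dimensions from constant ones, the resulting jax.ShapeDtypeStruct is exported once, and the same Exported is then called with two batch sizes. It assumes symbolic_args_specs accepts a single example array together with a single shape string.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# A shape with one symbolic dimension "b" and one constant dimension 4.
b, n = export.symbolic_shape("b, 4")
print(export.is_symbolic_dim(b), export.is_symbolic_dim(n))  # True False

# Export once for any batch size b.
spec = jax.ShapeDtypeStruct((b, n), jnp.float32)
exp = export.export(jax.jit(lambda x: jnp.sum(x, axis=1)))(spec)
print(exp.in_avals[0].shape, exp.out_avals[0].shape)

# The same Exported accepts different concrete batch sizes at call time.
shapes = [exp.call(np.ones((k, 4), np.float32)).shape for k in (1, 3)]
print(shapes)  # [(1,), (3,)]

# symbolic_args_specs derives specs from example arrays plus shape strings.
spec2 = export.symbolic_args_specs(np.ones((3, 4), np.float32), "b, 4")
print(spec2.shape)
```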
Constants
- jax.export.minimum_supported_serialization_version
The minimum supported serialization version; see Calling convention versions.
- jax.export.maximum_supported_serialization_version
The maximum supported serialization version; see Calling convention versions.
