jax.export module


jax.export is a library for exporting and serializing JAX functions for persistent archival.

See the Exporting and serialization documentation.
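The following is a minimal sketch of the typical workflow, assuming a recent JAX release where jax.export is available. It exports a jitted function for concrete arguments, serializes it, then deserializes it and calls it. The function f is an illustrative placeholder.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sin(x) * 2.0

x = np.arange(4, dtype=np.float32)

# export() takes a jax.jit-wrapped function and returns a callable that
# traces and lowers it for the given arguments, producing an Exported.
exp = export.export(jax.jit(f))(x)

# serialize() produces bytes suitable for persistent storage.
blob = exp.serialize()

# Later (possibly in another process), rehydrate and invoke it.
rehydrated = export.deserialize(blob)
result = rehydrated.call(x)
print(result)
```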

Classes

class jax.export.Exported(fun_name, in_tree, in_avals, out_tree, out_avals, _has_named_shardings, _in_named_shardings, _out_named_shardings, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)

A JAX function lowered to StableHLO.

Parameters:
fun_name

the name of the exported function, for error messages.

Type:

str

in_tree

a PyTreeDef describing the tuple (args, kwargs) of the lowered JAX function. The actual lowering does not depend on the in_tree, but this can be used to invoke the exported function using the same argument structure.

Type:

tree_util.PyTreeDef
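This small sketch shows how in_tree records the (args, kwargs) structure. The function f and its keyword argument scale are illustrative assumptions. The comparison relies on in_tree being the tree structure of the (args, kwargs) tuple, as described above.

```python
import numpy as np
import jax
from jax import export

def f(x, *, scale):
    return x * scale

x = np.ones((3,), dtype=np.float32)
scale = np.float32(2.0)

exp = export.export(jax.jit(f))(x, scale=scale)

# in_tree describes the (args, kwargs) pair used at export time.
expected_tree = jax.tree_util.tree_structure(((x,), {"scale": scale}))
print(exp.in_tree == expected_tree)

# The same calling structure is used to invoke the exported function.
y = exp.call(x, scale=scale)
```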

in_avals

the flat tuple of input abstract values. May contain dimension expressions in the shapes.

Type:

tuple[core.ShapedArray, …]
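The sketch below shows that in_avals and out_avals carry only shapes and dtypes. This means an export can be produced from jax.ShapeDtypeStruct specs without any real data. The matrix-product function is an arbitrary example.

```python
import jax
import jax.numpy as jnp
from jax import export

# ShapeDtypeStruct lets us export without materializing real data.
spec = jax.ShapeDtypeStruct((2, 3), jnp.float32)
exp = export.export(jax.jit(lambda a, b: a @ b.T))(spec, spec)

for aval in exp.in_avals:
    print(aval.shape, aval.dtype)  # (2, 3) float32, printed twice
print(exp.out_avals[0].shape)  # (2, 2)
```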

out_tree

a PyTreeDef describing the result of the lowered JAX function.

Type:

tree_util.PyTreeDef

out_avals

the flat tuple of output abstract values. May contain dimension expressions in the shapes, with dimension variables among those in in_avals. Note that when out_shardings are not specified for an output, out_avals.sharding.spec may be None for Auto axes, even if the compiler picks a non-replicated sharding during compilation. See https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html#concrete-array-shardings-can-mention-auto-mesh-axis for more details.

Type:

tuple[core.ShapedArray, …]
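This sketch shows dimension variables flowing from in_avals to out_avals. It assumes shape-polymorphic export via export.symbolic_shape, and the row-sum function is illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

# "b" is a dimension variable; the exported function accepts any batch size.
b, = export.symbolic_shape("b")
x_spec = jax.ShapeDtypeStruct((b, 4), jnp.float32)

exp = export.export(jax.jit(lambda x: jnp.sum(x, axis=1)))(x_spec)

# The output shape is expressed in terms of the same dimension variable.
print(exp.out_avals[0].shape)

# At call time, "b" is resolved from the concrete argument shape.
out3 = exp.call(np.ones((3, 4), dtype=np.float32))
out5 = exp.call(np.ones((5, 4), dtype=np.float32))
```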

in_shardings_hlo

the flattened input shardings, a sequence as long as in_avals. None means unspecified sharding. Note that these do not include the mesh or the actual devices used in the mesh, and in general you should avoid using this field directly. See in_shardings_jax for a way to turn these into sharding specifications that can be used with JAX APIs.

Type:

tuple[HloSharding | None, …]

out_shardings_hlo

the flattened output shardings, a sequence as long as out_avals. None means unspecified sharding. Note that these do not include the mesh or the actual devices used in the mesh, and in general you should avoid using this field directly. See out_shardings_jax for a way to turn these into sharding specifications that can be used with JAX APIs.

Type:

tuple[HloSharding | None, …]
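This sketch inspects the sharding fields for a function jitted over a 1-D mesh that spans all local devices. It also runs with a single device. The mesh axis name "a" and the function are illustrative. The printed HloSharding values depend on the device count, so the checks only cover the lengths and nr_devices.

```python
import numpy as np
import jax
from jax import export
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()
# A 1-D mesh over all available devices (a single device also works).
mesh = Mesh(np.array(jax.devices()), ("a",))
sharded_f = jax.jit(lambda x: x * 2.0, in_shardings=NamedSharding(mesh, P("a")))

x = np.arange(2 * n, dtype=np.float32)
exp = export.export(sharded_f)(x)

# One entry per flattened input/output; None marks an unspecified sharding.
print(exp.in_shardings_hlo)
print(exp.out_shardings_hlo)
print(exp.nr_devices)
```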

nr_devices

the number of devices that the module has been lowered for.

Type:

int

platforms

a tuple containing the platforms for which the function should be exported. The set of platforms in JAX is open-ended; users can add platforms. JAX built-in platforms are: ‘tpu’, ‘cpu’, ‘cuda’, ‘rocm’. See https://docs.jax.dev/en/latest/export/export.html#cross-platform-and-multi-platform-export.

Type:

tuple[str, …]
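This sketch shows default and explicit platform selection. It assumes, as the cross-platform export documentation describes, that lowering for 'tpu' does not require a TPU to be present. The sketch only inspects the recorded platforms and does not execute the TPU variant.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

# By default, export lowers for the platform JAX would use locally.
exp_default = export.export(jax.jit(jnp.cos))(np.float32(1.0))
print(exp_default.platforms, export.default_export_platform())

# Lowering for other platforms does not require that hardware locally.
exp_multi = export.export(jax.jit(jnp.cos), platforms=("cpu", "tpu"))(np.float32(1.0))
print(exp_multi.platforms)  # ('cpu', 'tpu')
```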

ordered_effects

the ordered effects present in the serialized module. This is present from serialization version 9. See https://docs.jax.dev/en/latest/export/export.html#module-calling-convention for the calling convention in the presence of ordered effects.

Type:

tuple[effects.Effect, …]

unordered_effects

the unordered effects present in the serialized module. This is present from serialization version 9.

Type:

tuple[effects.Effect, …]

mlir_module_serialized

the serialized lowered VHLO module.

Type:

bytes

calling_convention_version

a version number for the calling convention of the exported module. See more versioning details at https://docs.jax.dev/en/latest/export/export.html#calling-convention-versions.

Type:

int
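This sketch checks an exported module's calling convention version against the range that the running JAX supports. It uses the module-level constants listed under Functions below.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

exp = export.export(jax.jit(jnp.tanh))(np.float32(0.5))

lo = export.minimum_supported_calling_convention_version
hi = export.maximum_supported_calling_convention_version
# A freshly exported module should fall within the supported range.
print(lo, exp.calling_convention_version, hi)
```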

module_kept_var_idx

the sorted indices of the arguments among in_avals that must be passed to the module. The other arguments have been dropped because they are not used.

Type:

tuple[int, …]
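This sketch illustrates argument pruning and assumes jax.jit's default keep_unused=False. The second argument below is never used, so it is expected to be absent from module_kept_var_idx. Callers still pass it, because in_avals lists every argument.

```python
import numpy as np
import jax
from jax import export

def f(x, unused):
    return x * 2.0

x = np.float32(3.0)
exp = export.export(jax.jit(f))(x, x)

# Only argument 0 is expected to reach the module; argument 1 is unused.
print(exp.module_kept_var_idx)
# Callers still pass every argument; dropped ones are filtered internally.
print(exp.call(x, x))
```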

uses_global_constants

whether the mlir_module_serialized uses shape polymorphism or multi-platform export. This may be because in_avals contains dimension variables, or due to inner calls of Exported modules that have dimension variables or platform index arguments. Such modules need shape refinement before XLA compilation.

Type:

bool

disabled_safety_checks

a list of descriptors of safety checks that have been disabled at export time. See the docstring for DisabledSafetyCheck.

Type:

Sequence[DisabledSafetyCheck]

_get_vjp

an optional function that takes the current exported function and returns the exported VJP function. The VJP function takes a flat list of arguments, starting with the primal arguments and followed by a cotangent argument for each primal output. It returns a tuple with the cotangents corresponding to the flattened primal inputs.

Type:

Callable[[Exported], Exported] | None

Do not rely directly on fields whose names start with ‘_’; they will change.

See a description of the calling convention for the mlir_module() method at https://docs.jax.dev/en/latest/export/export.html#module-calling-convention.

call(*args, **kwargs)

Call an exported function from a JAX program.

Parameters:
  • args – the positional arguments to pass to the exported function. This should be a pytree of arrays with the same pytree structure as the arguments for which the function was exported.

  • kwargs – the keyword arguments to pass to the exported function.

Returns: a pytree of result arrays, with the same structure as the results of the exported function.

The invocation supports reverse-mode AD, and all the features supported by exporting: shape polymorphism, multi-platform, device polymorphism. See the examples in the [JAX export documentation](https://docs.jax.dev/en/latest/export/export.html).
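This sketch uses call inside other JAX transformations, as the paragraph above describes: once under jax.jit and once under jax.grad (reverse-mode AD). The function sin(x)**2 is illustrative, and its derivative is sin(2x).

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

exp = export.export(jax.jit(lambda x: jnp.sin(x) ** 2))(np.float32(0.0))

# Exported.call composes with jax.jit like any other JAX function.
@jax.jit
def g(x):
    return exp.call(x) + 1.0

x = np.float32(0.7)
val = g(x)
# Reverse-mode AD through the exported function: d/dx sin(x)^2 = sin(2x).
grad = jax.grad(exp.call)(x)
```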

has_vjp()

Returns whether this Exported supports VJP.

Return type:

bool

in_shardings_jax(mesh)

Creates Shardings corresponding to self.in_shardings_hlo and self._in_named_shardings.

The Exported object stores in_shardings_hlo as HloShardings and, after 12/5/2025, also _in_named_shardings as NamedShardings with abstract meshes. This method constructs Shardings that can be used in JAX APIs such as jax.jit() or jax.device_put(). The mesh argument may be a concrete mesh.

Example usage:

>>> import jax
>>> import numpy as np
>>> from jax import export, sharding
>>> # Prepare the exported object (the outputs below assume 8 CPU devices):
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
...     in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
... )(np.arange(jax.device_count()))
>>> exp.in_shardings_hlo
({devices=[8]<=[8]},)
>>> # Create a mesh for running the exported object
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("a",))
>>> # Put the args and kwargs on the appropriate devices
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
...     exp.in_shardings_jax(run_mesh)[0])
>>> res = exp.call(run_arg)
>>> res.addressable_shards
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]), Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]), Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]), Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]), Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]), Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]), Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]), Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
Parameters:

mesh (mesh_lib.Mesh)

Return type:

Sequence[sharding.Sharding | None]

mlir_module()

A string representation of the mlir_module_serialized.

Return type:

str
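This sketch prints the StableHLO text of a small export. The exact text varies between JAX versions, so you should rely only on its general form: a string containing StableHLO operations.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

exp = export.export(jax.jit(jnp.exp))(np.ones((2,), np.float32))

# Human-readable form of exp.mlir_module_serialized.
text = exp.mlir_module()
print(text)
```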

out_shardings_jax(mesh)

Creates Shardings for out_shardings_hlo and _out_named_shardings.

See the documentation for in_shardings_jax.

Parameters:

mesh (mesh_lib.Mesh)

Return type:

Sequence[sharding.Sharding | None]

serialize(vjp_order=0)

Serializes an Exported.

Parameters:

vjp_order (int) – The maximum VJP order to include. E.g., the value 2 means that we serialize the primal function and two orders of the VJP function. This should allow second-order reverse-mode differentiation of the deserialized function, i.e., jax.grad(jax.grad(f)).

Return type:

bytearray
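This sketch shows how vjp_order affects what survives serialization. With the default of 0, the deserialized object has no VJP. With vjp_order=2, jax.grad(jax.grad(...)) should work, as described above. The example uses jnp.sin because its second derivative, -sin(x), is easy to check.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

exp = export.export(jax.jit(jnp.sin))(np.float32(1.0))

# vjp_order=0 (default): only the primal function is stored.
primal_only = export.deserialize(exp.serialize())
# vjp_order=2: the first- and second-order VJPs are stored as well.
with_vjps = export.deserialize(exp.serialize(vjp_order=2))

x = np.float32(1.0)
# The second derivative of sin(x) is -sin(x).
d2 = jax.grad(jax.grad(with_vjps.call))(x)
```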

vjp()

Gets the exported VJP.

Returns None if not available, which can happen if the Exported has been loaded from an external format without a VJP.

Return type:

Exported
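This sketch calls the exported VJP directly. It follows the calling convention described for _get_vjp: the flat primal arguments come first, followed by one cotangent per primal output. For x * y, the cotangents are (ct * y, ct * x).

```python
import numpy as np
import jax
from jax import export

exp = export.export(jax.jit(lambda x, y: x * y))(np.float32(2.0), np.float32(3.0))

vjp_exp = exp.vjp()
# Arguments: the flat primal inputs (x, y), then one cotangent per output.
ct_x, ct_y = vjp_exp.call(np.float32(2.0), np.float32(3.0), np.float32(1.0))
```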

class jax.export.DisabledSafetyCheck(_impl)

A safety check that should be skipped on (de)serialization.

Most of these checks are performed on serialization, but some are deferred to deserialization. The list of disabled checks is attached to the serialization, e.g., as a sequence of string attributes to jax.export.Exported or of tf.XlaCallModuleOp.

When using jax2tf, you can disable more deserialization safety checks by passing TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform.

Parameters:

_impl (str)

classmethod custom_call(target_name)

Allows the serialization of a call target not known to be stable.

Has effect only on serialization.

Parameters:

target_name (str) – the name of the custom call target to allow.

Return type:

DisabledSafetyCheck

is_custom_call()

Returns the custom call target allowed by this directive.

Return type:

str | None

classmethod platform()

Allows the compilation platform to differ from the export platform.

Has effect only on deserialization.

Return type:

DisabledSafetyCheck
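This sketch constructs both kinds of DisabledSafetyCheck and passes one through the disabled_checks parameter of export. The custom-call target name is hypothetical and serves only to demonstrate is_custom_call().

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

platform_check = export.DisabledSafetyCheck.platform()
# "my_custom_target" is a hypothetical custom-call name for illustration.
custom = export.DisabledSafetyCheck.custom_call("my_custom_target")

print(platform_check.is_custom_call())  # None
print(custom.is_custom_call())          # my_custom_target

# Disabled checks are recorded on the resulting Exported.
exp = export.export(jax.jit(jnp.cos), disabled_checks=[platform_check])(np.float32(0.0))
print(exp.disabled_safety_checks)
```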

Functions

export(fun_jit, *[, platforms, ...])

Exports a JAX function for persistent serialization.

deserialize(blob)

Deserializes an Exported.

minimum_supported_calling_convention_version

The minimum supported calling convention version (an int).

maximum_supported_calling_convention_version

The maximum supported calling convention version (an int).

default_export_platform()

Retrieves the default export platform.

register_pytree_node_serialization(nodetype, ...)

Registers a custom PyTree node for serialization and deserialization.

register_namedtuple_serialization(nodetype, ...)

Registers a namedtuple for serialization and deserialization.

Functions related to shape polymorphism

symbolic_shape(shape_spec, *[, constraints, ...])

Constructs a symbolic shape from a string representation.

symbolic_args_specs(args, shapes_specs[, ...])

Constructs a pytree of jax.ShapeDtypeStruct argument specs for export.

is_symbolic_dim(p)

Checks if a dimension is symbolic.

SymbolicScope([constraints_str])

Identifies a scope for symbolic expressions.
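This sketch combines the shape-polymorphism helpers above:

- symbolic_args_specs builds specs from concrete arrays.
- is_symbolic_dim distinguishes symbolic dimensions from constant ones.
- symbolic_shape accepts constraints.

The concatenation function and the constraint "n >= 2" are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

x = np.zeros((3, 8), dtype=np.float32)

# Keep the trailing dimension fixed at 8 and make the leading one symbolic.
(x_spec,) = export.symbolic_args_specs((x,), ("b, 8",))
b, fixed = x_spec.shape
print(export.is_symbolic_dim(b), export.is_symbolic_dim(fixed))  # True False

# Dimension variables may carry constraints on their values.
n, = export.symbolic_shape("n", constraints=("n >= 2",))

# The output's leading dimension is the expression 2*b.
exp = export.export(jax.jit(lambda v: jnp.concatenate([v, v], axis=0)))(x_spec)
doubled = exp.call(np.ones((5, 8), dtype=np.float32))
print(doubled.shape)  # (10, 8)
```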

Constants

jax.export.minimum_supported_serialization_version

The minimum supported serialization version; see Calling convention versions.

jax.export.maximum_supported_serialization_version

The maximum supported serialization version; see Calling convention versions.
