jax.stages module
Interfaces to stages of the compiled execution process.
JAX transformations that compile just in time for execution, such as jax.jit and jax.pmap, also support a common means of explicit lowering and compilation ahead of time. This module defines types that represent the stages of this process.
For more, see the AOT walkthrough.
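A minimal sketch of the full pipeline, assuming a recent JAX release on any backend (CPU is enough). The function `f` and its inputs are illustrative only; the stage objects are the ones this module describes.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # A small numeric function to stage through the AOT pipeline.
    return jnp.sin(x) * y + 1.0


x = jnp.arange(4, dtype=jnp.float32)
y = jnp.float32(2.0)

wrapped = jax.jit(f)           # a jax.stages.Wrapped-like callable
lowered = wrapped.lower(x, y)  # jax.stages.Lowered: staged out, not yet compiled
compiled = lowered.compile()   # jax.stages.Compiled: ready to execute
result = compiled(x, y)        # runs the precompiled executable

print(result)
```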
Classes
- class jax.stages.Wrapped(*args, **kwargs)
A function ready to be traced, lowered, and compiled.
This protocol reflects the output of functions such as jax.jit. Calling it results in JIT (just-in-time) lowering, compilation, and execution. It can also be explicitly lowered prior to compilation, and the result compiled prior to execution.
- lower(*args, **kwargs)
Lower this function explicitly for the given arguments.
This is a shortcut for self.trace(*args, **kwargs).lower(). A lowered function is staged out of Python and translated to a compiler’s input language, possibly in a backend-dependent manner. It is ready for compilation but not yet compiled.
- Returns:
A Lowered instance representing the lowering.
- Return type:
Lowered
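As a sketch of the shortcut described above (assuming a JAX version in which `trace` is available on jitted functions, as the docstring implies), the one-step and two-step forms below should both yield `Lowered` objects that compile to the same computation. The function `f` is a hypothetical example.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return jnp.tanh(x) ** 2


x = jnp.linspace(-1.0, 1.0, 3, dtype=jnp.float32)
jitted = jax.jit(f)

# One-step form: lower directly for these argument types.
lowered_direct = jitted.lower(x)
# Two-step form that lower() is documented as a shortcut for.
lowered_via_trace = jitted.trace(x).lower()

out_direct = lowered_direct.compile()(x)
out_via_trace = lowered_via_trace.compile()(x)
```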
- class jax.stages.Traced(meta_tys_flat, params, in_tree, out_tree, consts)
Traced form of a function specialized to argument types and values.
A traced computation is ready for lowering. This class carries the traced representation with the remaining information needed to later lower, compile, and execute it.
Provides access to both the hijax (high-level) and lojax (low-level) representations via the .jaxpr and .lojax properties, respectively.
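The sketch below traces a hypothetical function and inspects its high-level jaxpr before lowering. It uses only the `.jaxpr` property; `.lojax` is left out because its exact output may depend on the JAX version.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return jnp.exp(x) + x


traced = jax.jit(f).trace(jnp.zeros((2,), jnp.float32))

# High-level jaxpr for the traced function, specialized to f32[2].
jaxpr_text = str(traced.jaxpr)
print(jaxpr_text)

# A Traced object can then be lowered explicitly.
lowered = traced.lower()
```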
- class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False, in_types=None, out_types=None)
Lowering of a function specialized to argument types and values.
A lowering is a computation ready for compilation. This class carries a lowering together with the remaining information needed to later compile and execute it. It also provides a common API for querying properties of lowered computations across JAX’s various lowering paths (jit(), pmap(), etc.).
- Parameters:
lowering (Lowering)
args_info (Any)
out_tree (tree_util.PyTreeDef)
no_kwargs (bool)
- as_text(dialect=None, *, debug_info=False)
A human-readable text representation of this lowering.
Intended for visualization and debugging purposes. This need not be a valid or reliable serialization. Use jax.export if you want reliable and portable serialization.
- compile(compiler_options=None, *, device_assignment=None)
Compile, returning a corresponding Compiled instance.
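Because a lowering is specialized to argument types rather than values, one common pattern (sketched here, not the only one) is to lower from an abstract `jax.ShapeDtypeStruct` and compile before any concrete data exists. The resulting `Compiled` object then accepts arrays with the same shape and dtype. The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return jnp.cumsum(x)


# Abstract argument: only shape and dtype, no data.
spec = jax.ShapeDtypeStruct((5,), jnp.float32)
compiled = jax.jit(f).lower(spec).compile()

# Execute with a concrete array matching the compiled shape and dtype.
out = compiled(jnp.arange(5, dtype=jnp.float32))
print(out)
```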
- compiler_ir(dialect=None)
An arbitrary object representation of this lowering.
Intended for debugging purposes. This is neither a valid nor a reliable serialization. The output has no guarantee of consistency across invocations. Use jax.export if you want reliable and portable serialization.
Returns None if unavailable, e.g. based on backend, compiler, or runtime.
- Parameters:
dialect (str | None) – Optional string specifying a lowering dialect (e.g. “stablehlo” or “hlo”).
- Return type:
Any | None
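A minimal sketch of fetching the compiler IR object, assuming the `"stablehlo"` dialect is available on the current backend. The code checks for `None`, as the docstring allows, and the function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return jnp.dot(x, x)


lowered = jax.jit(f).lower(jnp.ones((3,), jnp.float32))
ir_module = lowered.compiler_ir(dialect="stablehlo")

if ir_module is None:
    print("compiler IR unavailable for this backend/runtime")
else:
    # Typically an MLIR module object; str() renders it, but the
    # rendering is not a stable or portable format.
    print(type(ir_module).__name__)
    print(str(ir_module)[:200])
```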
- cost_analysis()
A summary of execution cost estimates.
Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.
Returns None if unavailable, e.g. based on backend, compiler, or runtime.
- Return type:
Any | None
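Since the result may be `None` and its structure is not stable, a defensive sketch treats it as an opaque value and only prints it. The matrix-multiply function `f` is a hypothetical example.

```python
import jax
import jax.numpy as jnp


def f(a, b):
    # Hypothetical example function: a small matrix multiply.
    return a @ b


a = jnp.ones((8, 8), jnp.float32)
lowered = jax.jit(f).lower(a, a)
cost = lowered.cost_analysis()

if cost is None:
    print("cost analysis unavailable for this backend/runtime")
else:
    # Treat the structure as opaque: keys and nesting may change.
    print(cost)
```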
- class jax.stages.Compiled(executable, const_args, args_info, out_tree, no_kwargs=False, in_types=None, out_types=None)
Compiled representation of a function specialized to types/values.
A compiled computation is associated with an executable and the remaining information needed to execute it. It also provides a common API for querying properties of compiled computations across JAX’s various compilation paths and backends.
- Parameters:
const_args (list[ArrayLike])
args_info (Any)
out_tree (tree_util.PyTreeDef)
- as_text()
A human-readable text representation of this executable.
Intended for visualization and debugging purposes. This is neither a valid nor a reliable serialization.
Returns None if unavailable, e.g. based on backend, compiler, or runtime.
- Return type:
str | None
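The sketch below compiles a hypothetical function and prints the first line of the compiled text, guarding against `None`. Comparing this output with `Lowered.as_text()` is one way to see how the compiler transformed the program, although neither format is guaranteed to be stable.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return jnp.maximum(x, 0.0) * 3.0


compiled = jax.jit(f).lower(jnp.ones((4,), jnp.float32)).compile()
hlo_text = compiled.as_text()

if hlo_text is None:
    print("compiled text unavailable for this backend/runtime")
else:
    # Print only the first line; the full text can be long.
    print(hlo_text.splitlines()[0])
```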
- cost_analysis()
A summary of execution cost estimates.
Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.
Returns None if unavailable, e.g. based on backend, compiler, or runtime.
- Return type:
Any | None
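Because the structure of this result has differed between JAX releases (a dict in some, a list of dicts in others), this sketch uses a small helper, `estimated_flops`, which is hypothetical and not part of JAX, to pull out a `"flops"` entry when one exists and return `None` otherwise.

```python
import jax
import jax.numpy as jnp


def estimated_flops(cost):
    """Hypothetical helper (not part of JAX): best-effort 'flops' lookup."""
    if isinstance(cost, (list, tuple)):
        # Some releases return a list with one dict per executable.
        cost = cost[0] if cost else None
    if isinstance(cost, dict):
        return cost.get("flops")
    return None  # None input, or a structure this helper does not recognize


def f(a, b):
    # Hypothetical example function.
    return a @ b


a = jnp.ones((16, 16), jnp.float32)
compiled = jax.jit(f).lower(a, a).compile()
flops = estimated_flops(compiled.cost_analysis())
print("estimated flops:", flops)
```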
- property in_tree: tree_util.PyTreeDef
Tree structure of the pair (positional arguments, keyword arguments).
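For a function with two positional arguments and one keyword argument, `in_tree` describes the pair `((x, y), {"scale": ...})`, so it should have three leaves. The sketch below, with a hypothetical `f`, checks that count; calls whose argument structure differs from `in_tree` are expected to be rejected.

```python
import jax
import jax.numpy as jnp


def f(x, y, *, scale):
    # Hypothetical example function with one keyword argument.
    return (x + y) * scale


x = jnp.ones((2,), jnp.float32)
compiled = jax.jit(f).lower(x, x, scale=3.0).compile()

# Structure of ((positional args), {keyword args}).
print(compiled.in_tree)
num_inputs = compiled.in_tree.num_leaves

# Calls must match the lowered structure, including the keyword argument.
out = compiled(x, x, scale=3.0)
```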
- memory_analysis()
A summary of estimated memory requirements.
Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.
Returns None if unavailable, e.g. based on backend, compiler, or runtime.
- Return type:
Any | None
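A defensive sketch that reads a few fields from the memory-analysis result. The attribute names used here (`argument_size_in_bytes`, `output_size_in_bytes`, `temp_size_in_bytes`) are assumptions about the XLA-backed result object, so `getattr` with a default keeps the code working if a field is absent. The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return jnp.sort(x)


compiled = jax.jit(f).lower(jnp.zeros((1024,), jnp.float32)).compile()
stats = compiled.memory_analysis()

if stats is None:
    summary = {}
    print("memory analysis unavailable for this backend/runtime")
else:
    # Assumed field names; getattr returns None if one is missing.
    fields = ("argument_size_in_bytes", "output_size_in_bytes", "temp_size_in_bytes")
    summary = {name: getattr(stats, name, None) for name in fields}
    print(summary)
```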
- runtime_executable()
An arbitrary object representation of this executable.
Intended for debugging purposes. This is neither a valid nor a reliable serialization. The output has no guarantee of consistency across invocations.
Returns None if unavailable, e.g. based on backend, compiler, or runtime.
- Return type:
Any | None
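A minimal sketch of retrieving the underlying runtime object. Its concrete type is an implementation detail, so the example only prints the type name rather than relying on any of its methods. The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return x + 1.0


compiled = jax.jit(f).lower(jnp.ones((2,), jnp.float32)).compile()
exe = compiled.runtime_executable()

# The concrete type is an implementation detail; only inspect it for debugging.
print(type(exe).__name__ if exe is not None else "unavailable")
```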
