jax.stages module

Interfaces to stages of the compiled execution process.

JAX transformations that compile just in time for execution, such as jax.jit and jax.pmap, also support a common means of explicit lowering and compilation ahead of time. This module defines types that represent the stages of this process.

For more, see the AOT walkthrough.
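As an illustration, the sketch below runs a small function through each stage explicitly: jax.jit produces a Wrapped, .lower() produces a Lowered, and .compile() produces a Compiled, which is then called like the original function. The function f and its inputs are hypothetical and chosen only for the example; any backend should do.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical example computation: elementwise ops followed by a reduction.
    return jnp.sum(jnp.sin(x) * y)


x = jnp.arange(4.0)
y = jnp.ones(4)

jitted = jax.jit(f)           # a jax.stages.Wrapped
lowered = jitted.lower(x, y)  # a jax.stages.Lowered (not yet compiled)
compiled = lowered.compile()  # a jax.stages.Compiled

aot_result = compiled(x, y)   # run the ahead-of-time compiled executable
jit_result = jitted(x, y)     # the usual just-in-time path
print(aot_result, jit_result)
```

Both paths should produce the same value; the explicit path simply exposes each intermediate object for inspection.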

Classes

class jax.stages.Wrapped(*args, **kwargs)

A function ready to be traced, lowered, and compiled.

This protocol reflects the output of functions such as jax.jit. Calling it results in JIT (just-in-time) lowering, compilation, and execution. It can also be explicitly lowered prior to compilation, and the result compiled prior to execution.

__call__(*args, **kwargs)

Executes the wrapped function, lowering and compiling as needed.
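A minimal sketch of what "as needed" means in practice, assuming default jax.jit caching: Python side effects in the hypothetical function double run only while it is being traced, so the list records one entry per trace. Repeated calls with the same shape and dtype reuse the cached executable, while a new shape triggers a fresh trace and compilation.

```python
import jax
import jax.numpy as jnp

trace_count = []  # one entry is appended each time `double` is traced


@jax.jit
def double(x):
    # Python side effects run only during tracing, not on cached executions.
    trace_count.append(x.shape)
    return 2 * x


double(jnp.ones((3,), jnp.float32))  # first call: trace, lower, compile, run
double(jnp.ones((3,), jnp.float32))  # same shape/dtype: cached executable reused
double(jnp.ones((4,), jnp.float32))  # new shape: traced and compiled again
print(trace_count)  # [(3,), (4,)]
```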

lower(*args, **kwargs)

Lower this function explicitly for the given arguments.

This is a shortcut for self.trace(*args, **kwargs).lower().

A lowered function is staged out of Python and translated to a compiler’s input language, possibly in a backend-dependent manner. It is ready for compilation but not yet compiled.

Returns:

A Lowered instance representing the lowering.

Return type:

Lowered
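Because lowering depends only on argument shapes and dtypes, lower() can also be given jax.ShapeDtypeStruct placeholders instead of concrete arrays. The sketch below, built around a hypothetical matvec function, lowers without allocating any input data, then compiles and runs the result on real arrays of matching shape.

```python
import jax
import jax.numpy as jnp


def matvec(a, v):
    # Hypothetical example: matrix-vector product.
    return a @ v


# Describe argument shapes and dtypes without allocating any data.
a_spec = jax.ShapeDtypeStruct((8, 4), jnp.float32)
v_spec = jax.ShapeDtypeStruct((4,), jnp.float32)

lowered = jax.jit(matvec).lower(a_spec, v_spec)

# The lowering can later be compiled and run on concrete arrays of matching shape.
out = lowered.compile()(jnp.ones((8, 4), jnp.float32), jnp.ones((4,), jnp.float32))
print(out.shape)  # (8,)
```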

trace(*args, **kwargs)

Trace this function explicitly for the given arguments.

A traced function is staged out of Python and translated to a jaxpr. It is ready for lowering but not yet lowered.

Returns:

A Traced instance representing the tracing.

Return type:

Traced
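A sketch of the trace stage, assuming a JAX version in which Wrapped.trace and the Traced.jaxpr property are available as documented on this page: the traced object exposes the jaxpr for inspection and can then be lowered. The function f is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return jnp.sin(x) * 2.0


traced = jax.jit(f).trace(jnp.ones((3,), jnp.float32))
print(traced.jaxpr)       # the jaxpr recorded while tracing f

lowered = traced.lower()  # continue to the lowering stage
```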

class jax.stages.Traced(meta_tys_flat, params, in_tree, out_tree, consts)

Traced form of a function specialized to argument types and values.

A traced computation is ready for lowering. This class carries the traced representation with the remaining information needed to later lower, compile, and execute it.

Provides access to both the hijax (high-level) and lojax (low-level) representations via the .jaxpr and .lojax properties, respectively.

lower(*, lowering_platforms=None, _private_parameters=None)

Lower to compiler input, returning a Lowered instance.

Parameters:
  • lowering_platforms (tuple[str, ...] | None)

  • _private_parameters (mlir.LoweringParameters | None)
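A sketch of passing lowering_platforms, under the assumption that the local machine can lower for the "cpu" platform; other platform names and cross-platform lowering are not exercised here. The resulting Lowered is only inspected as text, not compiled, and the function g is hypothetical.

```python
import jax
import jax.numpy as jnp


def g(x):
    # Hypothetical example function.
    return jnp.cos(x) + 1.0


traced = jax.jit(g).trace(jax.ShapeDtypeStruct((2,), jnp.float32))

# Assumption: "cpu" is a valid target here; only the local platform is tried.
lowered_cpu = traced.lower(lowering_platforms=("cpu",))
text = lowered_cpu.as_text()
print(text)
```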

class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False, in_types=None, out_types=None)

Lowering of a function specialized to argument types and values.

A lowering is a computation ready for compilation. This class carries a lowering together with the remaining information needed to later compile and execute it. It also provides a common API for querying properties of lowered computations across JAX’s various lowering paths (jit(), pmap(), etc.).

Parameters:
  • lowering (Lowering)

  • args_info (Any)

  • out_tree (tree_util.PyTreeDef)

  • no_kwargs (bool)

as_text(dialect=None, *, debug_info=False)

A human-readable text representation of this lowering.

Intended for visualization and debugging purposes. This need not be a valid or reliable serialization. Use jax.export if you want reliable and portable serialization.

Parameters:
  • dialect (str | None) – Optional string specifying a lowering dialect (e.g. “stablehlo” or “hlo”).

  • debug_info (bool) – Whether to include debugging information, e.g., source location.

Return type:

str

compile(compiler_options=None, *, device_assignment=None)

Compile, returning a corresponding Compiled instance.

Parameters:
  • compiler_options (CompilerOptions | None)

  • device_assignment (tuple[xc.Device, ...] | None)

Return type:

Compiled
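A sketch of compiling and then calling the result, using a hypothetical add_one function. The Compiled object accepts arguments matching the types it was lowered for; in the JAX versions we assume, an argument of a different shape is rejected with a TypeError rather than triggering recompilation.

```python
import jax
import jax.numpy as jnp


def add_one(x):
    # Hypothetical example function.
    return x + 1.0


compiled = jax.jit(add_one).lower(jnp.zeros((3,), jnp.float32)).compile()

ok = compiled(jnp.zeros((3,), jnp.float32))  # matches the lowered types

rejected = False
try:
    compiled(jnp.zeros((4,), jnp.float32))   # different shape: not recompiled
except TypeError:
    rejected = True
print(ok, rejected)
```

This fixed signature is the main behavioral difference from calling the Wrapped function, which would simply retrace for the new shape.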

compiler_ir(dialect=None)

An arbitrary object representation of this lowering.

Intended for debugging purposes. This is neither a valid nor a reliable serialization. The output has no guarantee of consistency across invocations. Use jax.export if you want reliable and portable serialization.

Returns None if unavailable, e.g. depending on the backend, compiler, or runtime.

Parameters:

dialect (str | None) – Optional string specifying a lowering dialect (e.g. “stablehlo” or “hlo”).

Return type:

Any | None
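A sketch of retrieving the compiler IR for the "stablehlo" dialect while handling the documented None case. The concrete type of a non-None result (an MLIR module object in the versions we assume) is an implementation detail, so the example only renders it with str(). The function squash is hypothetical.

```python
import jax
import jax.numpy as jnp


def squash(x):
    # Hypothetical example function.
    return jnp.tanh(x)


lowered = jax.jit(squash).lower(jnp.ones((2,), jnp.float32))
ir = lowered.compiler_ir(dialect="stablehlo")

if ir is None:
    ir_text = None
    print("compiler IR unavailable for this backend")
else:
    # Render whatever object was returned as text for inspection.
    ir_text = str(ir)
    print(ir_text)
```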

cost_analysis()

A summary of execution cost estimates.

Intended for visualization and debugging purposes. The returned object is a simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.

Returns None if unavailable, e.g. depending on the backend, compiler, or runtime.

Return type:

Any | None
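Because the structure of the result is unspecified, code that reads cost estimates should be defensive. The helper first_cost_dict below is hypothetical: it assumes the result is None, a dict, or a list or tuple of dicts (the shape some older JAX versions returned), and it treats the "flops" key as optional.

```python
import jax
import jax.numpy as jnp


def first_cost_dict(analysis):
    """Hypothetical helper: reduce a cost_analysis() result to a dict or None.

    Assumes the result is None, a dict, or a list/tuple of dicts.
    """
    if isinstance(analysis, (list, tuple)):
        analysis = analysis[0] if analysis else None
    return analysis if isinstance(analysis, dict) else None


def matmul(a, b):
    # Hypothetical example function.
    return a @ b


lowered = jax.jit(matmul).lower(
    jnp.ones((16, 16), jnp.float32), jnp.ones((16, 16), jnp.float32)
)
cost = first_cost_dict(lowered.cost_analysis())
if cost is None:
    print("no cost estimate available")
else:
    print("estimated flops:", cost.get("flops"))  # key may be absent
```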

property in_tree: tree_util.PyTreeDef

Tree structure of the pair (positional arguments, keyword arguments).
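For instance, lowering a hypothetical function with two positional arguments and one keyword argument yields an in_tree with one leaf for each of those three inputs. This sketch assumes each argument is a single array, so each contributes exactly one leaf.

```python
import jax
import jax.numpy as jnp


def scaled_sum(x, y, *, scale):
    # Hypothetical example with two positional arguments and one keyword argument.
    return (x + y) * scale


x = jnp.ones((2,), jnp.float32)
lowered = jax.jit(scaled_sum).lower(x, x, scale=jnp.float32(3.0))

print(lowered.in_tree)             # structure of ((x, y), {"scale": ...})
print(lowered.in_tree.num_leaves)  # one leaf per array argument here
```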

class jax.stages.Compiled(executable, const_args, args_info, out_tree, no_kwargs=False, in_types=None, out_types=None)

Compiled representation of a function specialized to types/values.

A compiled computation is associated with an executable and the remaining information needed to execute it. It also provides a common API for querying properties of compiled computations across JAX’s various compilation paths and backends.

Parameters:
  • const_args (list[ArrayLike])

  • args_info (Any)

  • out_tree (tree_util.PyTreeDef)

__call__(*args, **kwargs)

Call self as a function.

as_text()

A human-readable text representation of this executable.

Intended for visualization and debugging purposes. This is neither a valid nor a reliable serialization.

Returns None if unavailable, e.g. depending on the backend, compiler, or runtime.

Return type:

str | None
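A sketch that prints the text of a compiled executable and handles the documented None case. On CPU and GPU this is typically the post-optimization HLO, which we assume here; comparing it with Lowered.as_text() can show what the compiler changed. The function f is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return jnp.exp(x) * 2.0


lowered = jax.jit(f).lower(jnp.ones((4,), jnp.float32))
compiled = lowered.compile()

optimized = compiled.as_text()  # may be None on some backends
if optimized is None:
    print("executable text unavailable for this backend")
else:
    print(optimized)
```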

cost_analysis()

A summary of execution cost estimates.

Intended for visualization and debugging purposes. The returned object is a simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.

Returns None if unavailable, e.g. depending on the backend, compiler, or runtime.

Return type:

Any | None

property in_tree: tree_util.PyTreeDef

Tree structure of the pair (positional arguments, keyword arguments).

memory_analysis()

A summary of estimated memory requirements.

Intended for visualization and debugging purposes. The returned object is a simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.

Returns None if unavailable, e.g. depending on the backend, compiler, or runtime.

Return type:

Any | None
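A defensive sketch of reading memory estimates. The attribute names argument_size_in_bytes, output_size_in_bytes and temp_size_in_bytes are an assumption based on the memory statistics object returned by current jaxlib backends; getattr with a default guards against versions or backends that differ. The function gram is hypothetical.

```python
import jax
import jax.numpy as jnp


def gram(a):
    # Hypothetical example function: Gram matrix of the rows of a.
    return a @ a.T


compiled = jax.jit(gram).lower(jnp.ones((32, 8), jnp.float32)).compile()
mem = compiled.memory_analysis()

# Assumed attribute names; getattr returns None for any that are missing.
fields = ("argument_size_in_bytes", "output_size_in_bytes", "temp_size_in_bytes")
sizes = {} if mem is None else {name: getattr(mem, name, None) for name in fields}
for name, value in sizes.items():
    print(name, value)
```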

runtime_executable()

An arbitrary object representation of this executable.

Intended for debugging purposes. This is neither a valid nor a reliable serialization. The output has no guarantee of consistency across invocations.

Returns None if unavailable, e.g. depending on the backend, compiler, or runtime.

Return type:

Any | None

