JAX documentation

Attaching XLA Metadata with set_xla_metadata

Summary: set_xla_metadata allows you to attach metadata to operations in your JAX code. This metadata is passed down to the XLA compiler as frontend_attributes and can be used to enable compiler-level debugging tools, such as the XLA-TPU debugger.

You can use it in three ways:

  1. Tag an individual operation by wrapping its output value

  2. Tag a block of operations using a context manager

  3. Tag all operations in a function using a decorator

Warning: set_xla_metadata is an experimental feature and its API is subject to change.

What is XLA Metadata?

When JAX transforms and compiles your code, it ultimately generates an XLA (Accelerated Linear Algebra) computation graph. Each operation in this graph can have associated metadata, specifically frontend_attributes. This metadata doesn't change the numerical result of the operation, but it can be used to signal special behavior to the compiler or runtime.

set_xla_metadata provides a way to attach this metadata directly from your JAX code. This is a powerful feature for low-level debugging and profiling.

Usage

Tagging Individual Operations

Tagging an individual operation gives you precise control over which parts of your computation you want to inspect. To do this, wrap the output value of an operation with set_xla_metadata. If the wrapped value is produced by several operations, only the final operation that produces it is tagged: in the example below, the multiply receives the attribute, while the sin and cos that feed it do not.

```python
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging an individual operation
def value_tagging(x):
  y = jnp.sin(x)
  z = jnp.cos(x)
  return set_xla_metadata(y * z, breakpoint=True)

print(jax.jit(value_tagging).lower(1.0).as_text("hlo"))
```

Results in:

```
ENTRY main.5 {
  x.1 = f32[] parameter(0)
  sin.2 = f32[] sine(x.1)
  cos.3 = f32[] cosine(x.1)
  ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={breakpoint="true"}
}
```

Tagging a Block of Code with a Context Manager or Decorator

If you want to apply the same metadata to a larger section of code, you can use set_xla_metadata as a context manager. All JAX operations within the with block will have the specified metadata attached.

```python
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging a block of code
def context_tagging(x):
  with set_xla_metadata(_xla_log=True):
    y = jnp.sin(x)
    z = jnp.cos(y)
    return y * z

print(jax.jit(context_tagging).lower(1.0).as_text("hlo"))
```

Results in:

```
ENTRY main.5 {
  x.1 = f32[] parameter(0)
  sin.2 = f32[] sine(x.1), frontend_attributes={_xla_log="true"}
  cos.3 = f32[] cosine(sin.2), frontend_attributes={_xla_log="true"}
  ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={_xla_log="true"}
}
```

If you want to tag all operations in a function, you can also use set_xla_metadata as a decorator:

```python
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging with a decorator
@set_xla_metadata(_xla_log=True)
@jax.jit
def decorator_tagging(x):
  y = jnp.sin(x)
  z = jnp.cos(y)
  return y * z

print(decorator_tagging.lower(1.0).as_text("hlo"))
```

This will result in the same HLO as above.

Interaction with JAX Transformations

Depending on the use case, set_xla_metadata uses either an XlaMetadataContextManager or a JAX primitive, and it is compatible with JAX transformations such as jit, vmap, and grad.

  • vmap: When you vmap a function containing set_xla_metadata, the metadata is applied to all of the relevant batched operations.

  • grad:

    1. When you tag a block of operations with the context manager (with set_xla_metadata(...):), the metadata is applied to both the forward and backward passes of the operations within it.

    2. Tagging individual operations with set_xla_metadata() currently applies only to the forward pass of a function. To tag individual operations generated by the backward pass (i.e., the gradient computation), you can use a simple custom_vjp:

      ```python
      import jax
      import jax.numpy as jnp
      from jax.experimental.xla_metadata import set_xla_metadata

      def fn(x):
        y = jnp.sin(x)
        z = jnp.cos(x)
        return y * z

      metadata = {"example": "grad_tagging"}

      # --- Define Custom VJP to tag gradients ---
      @jax.custom_vjp
      def wrapped_fn(x):
        return fn(x)

      def fwd(*args):
        primal_out, vjp_fn = jax.vjp(fn, *args)
        return primal_out, vjp_fn

      def bwd(vjp_fn, cts_in):
        cts_out = vjp_fn(cts_in)
        cts_out = set_xla_metadata(cts_out, **metadata)
        return cts_out

      wrapped_fn.defvjp(fwd, bwd)
      # ------

      print(jax.jit(jax.grad(wrapped_fn)).lower(jnp.array(3.0)).as_text("hlo"))
      ```

      Results in:

      ```
      ENTRY main.10 {
        x.1 = f32[] parameter(0)
        sin.2 = f32[] sine(x.1)
        neg.6 = f32[] negate(sin.2)
        sin.5 = f32[] sine(x.1)
        mul.7 = f32[] multiply(neg.6, sin.5)
        cos.4 = f32[] cosine(x.1)
        cos.3 = f32[] cosine(x.1)
        mul.8 = f32[] multiply(cos.4, cos.3)
        ROOT add_any.9 = f32[] add(mul.7, mul.8), frontend_attributes={example="grad_tagging"}
      }
      ```

Strengths and Limitations of set_xla_metadata

Strengths

  • Variable Control: Allows you to target individual operations or blocks of operations.

  • Non-Intrusive: Does not change the numerical output or fusion behavior of your program.

  • Enables Compiler-Level Tooling: Gives compiler-level debugging and analysis tools, such as the XLA-TPU debugger, a way to target specific operations.

Limitations

  • Attributes may be lost: Although XLA metadata is intended to be preserved through transformations and HLO optimizations, certain edge cases may cause it to be dropped.

  • Forward-pass only for individual tags: When you tag individual operations, the metadata is not currently propagated to the corresponding gradient operations in the backward pass. To tag gradients in this case, use a custom_vjp, as in the example above.

  • Liable to change: set_xla_metadata is an experimental feature and its API is subject to change.

