Attaching XLA Metadata with set_xla_metadata
Summary: set_xla_metadata allows you to attach metadata to operations in your JAX code. This metadata is passed down to the XLA compiler as frontend_attributes and can be used to enable compiler-level debugging tools, such as the XLA-TPU debugger.
You can use it in three ways:
Tag an individual operation by wrapping its output value
Tag a block of operations using a context manager
Tag all operations in a function using a decorator
Warning: set_xla_metadata is an experimental feature and its API is subject to change.
What is XLA Metadata?
When JAX transforms and compiles your code, it ultimately generates an XLA (Accelerated Linear Algebra) computation graph. Each operation in this graph can have associated metadata, specifically frontend_attributes. This metadata doesn't change the numerical result of the operation, but it can be used to signal special behavior to the compiler or runtime.
set_xla_metadata lets you attach this metadata directly from your JAX code, which makes it useful for low-level debugging and profiling.
Usage
Tagging Individual Operations
Tagging an individual operation gives you precise control over which parts of your computation you want to inspect. To do this, wrap the output value of an operation with set_xla_metadata. Only the operation that directly produces the wrapped value is tagged: if that value is computed by several operations (in the example below, y * z depends on a sine and a cosine), only the final operation, here the multiply, receives the metadata.
```python
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging an individual operation
def value_tagging(x):
  y = jnp.sin(x)
  z = jnp.cos(x)
  return set_xla_metadata(y * z, breakpoint=True)

print(jax.jit(value_tagging).lower(1.0).as_text("hlo"))
```
Results in:
```text
ENTRY main.5 {
  x.1 = f32[] parameter(0)
  sin.2 = f32[] sine(x.1)
  cos.3 = f32[] cosine(x.1)
  ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={breakpoint="true"}
}
```
Tagging a Block of Code with a Context Manager or Decorator
If you want to apply the same metadata to a larger section of code, you can use set_xla_metadata as a context manager. All JAX operations within the with block will have the specified metadata attached.
```python
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging a block of code
def context_tagging(x):
  with set_xla_metadata(_xla_log=True):
    y = jnp.sin(x)
    z = jnp.cos(y)
    return y * z

print(jax.jit(context_tagging).lower(1.0).as_text("hlo"))
```
Results in:
```text
ENTRY main.5 {
  x.1 = f32[] parameter(0)
  sin.2 = f32[] sine(x.1), frontend_attributes={_xla_log="true"}
  cos.3 = f32[] cosine(sin.2), frontend_attributes={_xla_log="true"}
  ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={_xla_log="true"}
}
```
If you want to tag all operations in a function, you can also use set_xla_metadata as a decorator:
```python
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging with a decorator
@set_xla_metadata(_xla_log=True)
@jax.jit
def decorator_tagging(x):
  y = jnp.sin(x)
  z = jnp.cos(y)
  return y * z

print(decorator_tagging.lower(1.0).as_text("hlo"))
```
This will result in the same HLO as above.
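To confirm that two variants, such as the context-manager and decorator versions above, tag the same instructions, it can help to extract the frontend_attributes from the HLO text programmatically instead of reading it by eye. The helper below, tagged_instructions, is a hypothetical regex-based sketch and not a JAX API. It assumes one instruction per line, as in the printed HLO above, and attribute values that contain no braces or double quotes.

```python
import re

# Hypothetical helper (not part of JAX): matches one HLO instruction line,
# capturing the instruction name and the body of its frontend_attributes.
_INSTR = re.compile(
    r"^\s*(?:ROOT\s+)?(?P<name>[\w.\-]+)\s*=.*?frontend_attributes=\{(?P<attrs>[^}]*)\}"
)
# Matches key="value" pairs inside a frontend_attributes body.
_PAIR = re.compile(r'(\w+)="([^"]*)"')


def tagged_instructions(hlo_text):
  """Returns {instruction_name: {key: value}} for every tagged instruction."""
  result = {}
  for line in hlo_text.splitlines():
    m = _INSTR.match(line)
    if m:
      result[m.group("name")] = dict(_PAIR.findall(m.group("attrs")))
  return result


# The HLO printed for the context-manager example above.
hlo = """
ENTRY main.5 {
  x.1 = f32[] parameter(0)
  sin.2 = f32[] sine(x.1), frontend_attributes={_xla_log="true"}
  cos.3 = f32[] cosine(sin.2), frontend_attributes={_xla_log="true"}
  ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={_xla_log="true"}
}
"""
print(tagged_instructions(hlo))
```

Comparing tagged_instructions(a) == tagged_instructions(b) for the HLO of two variants then checks that both attach the same attributes to the same instructions. Instruction names such as sin.2 can differ between lowerings, so comparing only the attribute values may be more robust in practice.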
Interaction with JAX Transformations
set_xla_metadata uses either an XlaMetadataContextManager or a JAX primitive, depending on the use case, and is compatible with JAX transformations such as jit, vmap, and grad.
vmap: When you vmap a function containing set_xla_metadata, the metadata is applied to all of the relevant batched operations.
grad: When you tag a block of operations with the context manager (with set_xla_metadata(...):), the metadata is applied to both the forward pass and the backward pass of the operations within it. Tagging individual operations with set_xla_metadata() currently applies only to the forward pass of a function. To tag individual operations generated by the backward pass (i.e., the gradient computation), you can use a simple custom_vjp:
```python
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

def fn(x):
  y = jnp.sin(x)
  z = jnp.cos(x)
  return y * z

metadata = {"example": "grad_tagging"}

# --- Define Custom VJP to tag gradients ---
@jax.custom_vjp
def wrapped_fn(x):
  return fn(x)

def fwd(*args):
  # Run the original function and keep its VJP closure as the residual.
  primal_out, vjp_fn = jax.vjp(fn, *args)
  return primal_out, vjp_fn

def bwd(vjp_fn, cts_in):
  # Compute the cotangents, then tag the operations that produce them.
  cts_out = vjp_fn(cts_in)
  cts_out = set_xla_metadata(cts_out, **metadata)
  return cts_out

wrapped_fn.defvjp(fwd, bwd)
# ------

print(jax.jit(jax.grad(wrapped_fn)).lower(jnp.array(3.0)).as_text("hlo"))
```
Results in:
```text
ENTRY main.10 {
  x.1 = f32[] parameter(0)
  sin.2 = f32[] sine(x.1)
  neg.6 = f32[] negate(sin.2)
  sin.5 = f32[] sine(x.1)
  mul.7 = f32[] multiply(neg.6, sin.5)
  cos.4 = f32[] cosine(x.1)
  cos.3 = f32[] cosine(x.1)
  mul.8 = f32[] multiply(cos.4, cos.3)
  ROOT add_any.9 = f32[] add(mul.7, mul.8), frontend_attributes={example="grad_tagging"}
}
```
Strengths and Limitations of set_xla_metadata
Strengths
Fine-grained Control: Allows you to target individual operations or blocks of operations.
Non-Intrusive: Does not change the numerical output or fusion behavior of your program.
Enables Powerful Tooling: Unlocks the potential for sophisticated debugging and analysis at the compiler level.
Limitations
Attributes may be lost: Although XLA metadata is intended to be preserved through transformations and HLO optimizations, certain edge cases may cause it to be dropped.
Forward pass only: When you tag individual operations, the metadata is not currently propagated to the operations that compute their gradients. A custom_vjp must be used to tag gradients in this case; see the example above.
Liable to change: set_xla_metadata is an experimental feature and its API is subject to change.
