Debugging runtime values
Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page summarizes each tool; follow the "Complete guide" link at the top of each section to learn more.
Interactive inspection with `jax.debug`
Complete guide here
Summary: Use `jax.debug.print()` to print values to stdout in `jax.jit`-, `jax.pmap`-, and `pjit`-decorated functions, and `jax.debug.breakpoint()` to pause execution of your compiled function to inspect values in the call stack:
```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.breakpoint()
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
```
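`jax.debug.print` is staged into the traced program, so it keeps working under other transformations such as `jax.grad`. This is handy when you are chasing exploding gradients. The sketch below is illustrative only. The `loss` function and its inputs are hypothetical and were chosen so that the gradient can be checked by hand.

```python
import jax
import jax.numpy as jnp

def loss(w, x):  # hypothetical loss, for illustration only
  pred = jnp.tanh(w * x)
  # Prints the runtime value of `pred` even though `loss` is traced by grad and jit.
  jax.debug.print("pred = {p}", p=pred)
  return jnp.sum(pred ** 2)

# Differentiate with respect to the first argument `w`, then compile.
grad_fn = jax.jit(jax.grad(loss))

x = jnp.array([1.0, 2.0])
g = grad_fn(0.5, x)
print("dloss/dw =", g)
```

The print fires during the forward pass that `grad` runs internally. The returned gradient itself is unaffected by the debug call.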
Functional error checks with `jax.experimental.checkify`
Complete guide here
Summary: Checkify lets you add `jit`-able runtime error checking (e.g. out-of-bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
```python
from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
```
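When every check passes, the returned error is empty and `err.get()` returns `None`. The minimal sketch below contrasts a passing call and a failing call of the same checkified, jitted function. The function `safe_log` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def safe_log(x):  # hypothetical example function
  # User-defined check: log is only meaningful for strictly positive inputs.
  checkify.check(jnp.all(x > 0), "safe_log expects positive inputs")
  return jnp.log(x)

# checkify first (to functionalize the check), then jit.
checked_log = jax.jit(checkify.checkify(safe_log))

err_ok, y_ok = checked_log(jnp.array([1.0, 2.0]))
err_bad, y_bad = checked_log(jnp.array([1.0, -2.0]))

print(err_ok.get())   # no check failed, so get() returns None
print(err_bad.get())  # the message, followed by the failing source location
```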
You can also use checkify to automatically add common checks:
```python
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
```
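You can also enable a single error set instead of a union. The sketch below assumes `checkify.nan_checks` is available in your JAX version and turns on only the automatic NaN checks for a hypothetical `log_then_sum` function:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def log_then_sum(x):  # hypothetical example function
  return jnp.sum(jnp.log(x))

# Enable only the automatic NaN checks (no user or index checks).
checked = jax.jit(checkify.checkify(log_then_sum, errors=checkify.nan_checks))

err_bad, _ = checked(jnp.array([1.0, -1.0]))
err_ok, total = checked(jnp.array([1.0, 2.0]))

print(err_bad.get())  # describes the NaN and the primitive that produced it
print(err_ok.get())   # None: no NaN was produced
```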
Throwing Python errors with JAX’s debug flags
Complete guide here
Summary: Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap`- or `jax.pjit`-compiled code). Enable the `jax_disable_jit` flag to disable JIT compilation, so that you can use traditional Python debugging tools like `print` and `pdb`.
```python
import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
  return x / y

jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!
```
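The sketch below exercises both flags and uses hypothetical functions `f` and `g`. It first catches the `FloatingPointError` raised under `jax_debug_nans` and then restores the flag. It then uses the `jax.disable_jit()` context manager, a scoped alternative to setting `jax_disable_jit` globally. With JIT disabled, a plain `print` inside a jitted function shows a concrete value instead of a tracer.

```python
import jax
import jax.numpy as jnp

# Turn on NaN detection: a NaN in a jitted output raises FloatingPointError.
jax.config.update("jax_debug_nans", True)

@jax.jit
def f(x):  # hypothetical function
  return jnp.log(x)  # log of a negative number is NaN

try:
  f(-1.0)
  caught = None
except FloatingPointError as e:
  caught = e
  print("caught:", e)

jax.config.update("jax_debug_nans", False)  # restore the default

@jax.jit
def g(x):  # hypothetical function
  y = jnp.sin(x)
  # Under jit this prints an abstract tracer; with JIT disabled it prints a concrete value.
  print("y =", y)
  return 2.0 * y

with jax.disable_jit():
  out = g(1.0)
```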
Attaching XLA Metadata with `set_xla_metadata`
Complete guide here
Summary: `set_xla_metadata` allows you to attach metadata to operations in your JAX code. This metadata is passed down to the XLA compiler as `frontend_attributes` and can be used to enable compiler-level debugging tools, such as the XLA-TPU debugger.
Note: `set_xla_metadata` is an experimental feature and its API is subject to change.
```python
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging an individual operation
def value_tagging(x):
  y = jnp.sin(x)
  z = jnp.cos(x)
  return set_xla_metadata(y * z, breakpoint=True)

print(jax.jit(value_tagging).lower(1.0).as_text("hlo"))
```
Results in:
```
ENTRY main.5 {
  x.1 = f32[] parameter(0)
  sin.2 = f32[] sine(x.1)
  cos.3 = f32[] cosine(x.1)
  ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={breakpoint="true"}
}
```
