Introduction to debugging
Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? This section introduces you to a set of built-in JAX debugging methods that you can use with various JAX transformations.
Summary:
- Use jax.debug.print() to print values to stdout in jax.jit-, jax.pmap-, and pjit-decorated functions, and jax.debug.breakpoint() to pause execution of your compiled function to inspect values in the call stack.
- jax.experimental.checkify lets you add jit-able runtime error checking (e.g. out-of-bounds indexing) to your JAX code.
- JAX offers config flags and context managers that make errors easier to catch. For example, enable the jax_debug_nans flag to automatically detect when NaNs are produced in jax.jit-compiled code, and enable the jax_disable_jit flag to disable JIT compilation.
jax.debug.print for simple inspection
Here is a rule of thumb:
- Use jax.debug.print() for traced (dynamic) array values with jax.jit(), jax.vmap() and others.
- Use Python print() for static values, such as dtypes and array shapes.
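The rule of thumb can be seen in a single function. The sketch below uses a hypothetical `normalize` helper. It shows Python's print() reporting static metadata, such as shape and dtype, at trace time, while jax.debug.print() reports the traced values at runtime.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
  # Shape and dtype are static during tracing, so Python's print shows them
  # directly. This line runs only when the function is (re)traced.
  print("shape:", x.shape, "dtype:", x.dtype)
  # The contents of x are traced, so use jax.debug.print to see them at runtime.
  jax.debug.print("values: {}", x)
  total = jnp.sum(x)
  jax.debug.print("sum: {}", total)
  return x / total

out = normalize(jnp.array([1.0, 3.0]))
```

If you call `normalize` again with another float32 array of shape (2,), JAX reuses the cached compiled function. The Python print() line does not run a second time; only the jax.debug.print() lines do.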
Recall from Just-in-time compilation that when you transform a function with jax.jit(), the Python code is executed with abstract tracers in place of your arrays. Because of this, the Python print() function will only print the tracer:
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print("print(x) ->", x)
  y = jnp.sin(x)
  print("print(y) ->", y)
  return y

result = f(2.)
print(x) -> JitTracer<~float32[]>
print(y) -> JitTracer<~float32[]>
Python's print executes at trace time, before the runtime values exist. If you want to print the actual runtime values, you can use jax.debug.print():
@jax.jit
def f(x):
  jax.debug.print("jax.debug.print(x) -> {x}", x=x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {y}", y=y)
  return y

result = f(2.)
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314
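jax.debug.print() also accepts pytrees, such as a dict of parameters, as format arguments, not just single arrays. The sketch below uses a hypothetical `loss` function and parameter dict. The exact printed representation of the dict can vary between JAX versions, so no output is shown.

```python
import jax
import jax.numpy as jnp

@jax.jit
def loss(params, x):
  # The whole dict is passed through and formatted at runtime.
  jax.debug.print("params -> {p}", p=params)
  pred = params["w"] * x + params["b"]
  jax.debug.print("pred -> {}", pred)
  return jnp.mean(pred ** 2)

params = {"w": jnp.array(2.0), "b": jnp.array(1.0)}
value = loss(params, jnp.array([1.0, 2.0]))  # pred is [3., 5.]
```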
Similarly, within jax.vmap(), Python's print will only print the tracer. To print the values being mapped over, use jax.debug.print():
def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {}", y)
  return y

xs = jnp.arange(3.)
result = jax.vmap(f)(xs)
jax.debug.print(x) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(y) -> 0.9092974066734314
Here's the result with jax.lax.map(), which is a sequential map rather than a vectorization:
result = jax.lax.map(f, xs)
jax.debug.print(y) -> 0.0
jax.debug.print(x) -> 0.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(x) -> 1.0
jax.debug.print(y) -> 0.9092974066734314
jax.debug.print(x) -> 2.0
Notice that the order is different: jax.vmap() and jax.lax.map() compute the same results in different ways. When debugging, these evaluation-order details may be exactly what you need to inspect.
Below is an example with jax.grad(), where jax.debug.print() only prints the forward pass. In this case the behavior is similar to Python's print(), but jax.debug.print() keeps working the same way if you also apply jax.jit() to the call.
def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  return x ** 2

result = jax.grad(f)(1.)
jax.debug.print(x) -> 1.0
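Because jax.debug.print() inside the differentiated function only sees the forward pass, one way to inspect gradient values is to print what jax.grad() returns. The sketch below does this inside a jitted function, using a hypothetical `sgd_step` and a fixed learning rate of 0.1.

```python
import jax

@jax.jit
def sgd_step(x):
  # Print the gradient value itself rather than relying on prints in the
  # differentiated function, which only run on the forward pass.
  g = jax.grad(lambda v: v ** 2)(x)
  jax.debug.print("grad -> {}", g)
  return x - 0.1 * g

new_x = sgd_step(1.0)  # the gradient of v**2 at 1.0 is 2.0
```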
When the arguments don't depend on one another, calls to jax.debug.print() may print in a different order once a JAX transformation stages them out. If you need the original order, such as x: ... first and then y: ... second, add the ordered=True parameter.
For example:
@jax.jit
def f(x, y):
  jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
  jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
  return x + y

f(1, 2)
jax.debug.print(x) -> 1
jax.debug.print(y) -> 2
Array(3, dtype=int32, weak_type=True)
To learn more about jax.debug.print() and its Sharp Bits, refer to Advanced debugging.
jax.debug.breakpoint for pdb-like debugging
Summary: Use jax.debug.breakpoint() to pause the execution of your JAX program to inspect values.
To pause your compiled JAX program at specific points while debugging, you can use jax.debug.breakpoint(). The prompt is similar to Python's pdb and lets you inspect the values in the call stack. In fact, jax.debug.breakpoint() is an application of jax.debug.callback() that captures information about the call stack.
To print all available commands during a breakpoint debugging session, use the help command. (The full set of debugger commands, the Sharp Bits, and the debugger's strengths and limitations are covered in Advanced debugging.)
Here is an example of what a debugger session might look like:
@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z

f(2.)  # ==> Pauses during execution

For value-dependent breakpointing, you can use runtime conditionals like jax.lax.cond():
def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z

f(2., 1.)  # ==> No breakpoint
Array(2., dtype=float32, weak_type=True)
f(2., 0.)  # ==> Pauses during execution
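The same conditional pattern also works for non-interactive runs, such as a script or CI job, if you replace the breakpoint with a host callback that records the offending value. The sketch below assumes this substitution. `report` and `report_if_nonfinite` are hypothetical helpers, and jax.effects_barrier() is used to wait for pending callbacks before reading the list.

```python
import jax
import jax.numpy as jnp

nonfinite_reports = []

def report(x):
  # Runs on the host with the concrete value of x.
  nonfinite_reports.append(x)

def report_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.callback(report, x)
  jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  report_if_nonfinite(z)
  return z

f(2., 1.)  # finite: nothing recorded
f(2., 0.)  # inf: the callback records the value
jax.effects_barrier()  # make sure the callbacks have finished
```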
jax.debug.callback for more control during debugging
Both jax.debug.print() and jax.debug.breakpoint() are implemented using the more flexible jax.debug.callback(), which gives greater control over the host-side logic executed via a Python callback. It is compatible with jax.jit(), jax.vmap(), jax.grad() and other transformations (refer to the Flavors of callback table in External callbacks for more information).
For example:
import logging

def log_value(x):
  logging.warning(f'Logged value: {x}')

@jax.jit
def f(x):
  jax.debug.callback(log_value, x)
  return x

f(1.0);
WARNING:root:Logged value: 1.0
This callback is compatible with other transformations, including jax.vmap() and jax.grad():
x = jnp.arange(5.0)
jax.vmap(f)(x);
WARNING:root:Logged value: 0.0
WARNING:root:Logged value: 1.0
WARNING:root:Logged value: 2.0
WARNING:root:Logged value: 3.0
WARNING:root:Logged value: 4.0
jax.grad(f)(1.0);
WARNING:root:Logged value: 1.0
This makes jax.debug.callback() useful for general-purpose debugging.
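A callback can run arbitrary Python on the host, not just print or log. The sketch below collects every loop iterate of a jitted jax.lax.fori_loop into a Python list for later inspection. `record` and `halve_three_times` are hypothetical names. Unordered callbacks are not guaranteed to run in program order, so the check sorts the collected entries. jax.debug.callback() also accepts ordered=True if the order itself matters.

```python
import jax

history = []

def record(i, x):
  # Host-side Python: store plain numbers rather than arrays.
  history.append((int(i), float(x)))

@jax.jit
def halve_three_times(x):
  def body(i, x):
    x = x * 0.5
    jax.debug.callback(record, i, x)
    return x
  return jax.lax.fori_loop(0, 3, body, x)

result = halve_three_times(8.0)
jax.effects_barrier()  # wait for pending callbacks before reading history
```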
You can learn more about jax.debug.callback() and other kinds of JAX callbacks in External callbacks.
Read more in Compiled prints and breakpoints.
Functional error checks with jax.experimental.checkify
Summary: Checkify lets you add jit-able runtime error checking (e.g. out-of-bounds indexing) to your JAX code. Use the checkify.checkify transformation together with the assert-like checkify.check function to add runtime checks to JAX code:
from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
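The error value returned by a checkified function can be inspected without raising. err.get() returns None when no check failed, and a message string otherwise. The sketch below shows both cases with a hypothetical `safe_log` function.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def safe_log(x):
  checkify.check(jnp.all(x > 0), "safe_log expects positive inputs")
  return jnp.log(x)

checked_log = jax.jit(checkify.checkify(safe_log))

err, y = checked_log(jnp.array([1.0, 2.0]))
ok_message = err.get()   # no check failed

err, y = checked_log(jnp.array([1.0, -2.0]))
bad_message = err.get()  # describes the failed check
```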
You can also use checkify to automatically add common checks:
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

# Each err.throw() below raises, so run these one at a time.
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
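Because err.throw() raises only when a check has failed, you can handle the error like any other Python exception. The sketch below enables only the automatic index checks on a hypothetical `get_item` function. Following the outputs shown above, it assumes the raised error is a ValueError (or a subclass of it).

```python
import jax.numpy as jnp
from jax.experimental import checkify

def get_item(x, i):
  # Plain indexing: with index_checks enabled, checkify adds a bounds check.
  return x[i]

checked_get = checkify.checkify(get_item, errors=checkify.index_checks)

err, _ = checked_get(jnp.arange(3.0), 1)
err.throw()  # in bounds, so this does nothing

message = None
err, _ = checked_get(jnp.arange(3.0), 5)
try:
  err.throw()
except ValueError as e:  # assumed error type, per the outputs above
  message = str(e)
```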
Read more in The checkify transformation.
Throwing Python errors with JAX's debug flags
Summary: Enable the jax_debug_nans flag to automatically detect when NaNs are produced in jax.jit-compiled code (but not in jax.pmap or jax.pjit-compiled code), and enable the jax_disable_jit flag to disable JIT compilation so that you can use traditional Python debugging tools like print and pdb.
import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
  return x / y
jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!
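Since jax_debug_nans turns a NaN into a Python exception, you can catch it like any other exception. The sketch below enables the flag only around the code under investigation and restores it in a finally block. Restoring the flag afterwards is a suggested practice here, not a requirement.

```python
import jax

def f(x, y):
  return x / y

jax.config.update("jax_debug_nans", True)
try:
  jax.jit(f)(0., 0.)  # 0/0 produces a NaN
  raised = False
except FloatingPointError as e:
  raised = True
  print(type(e).__name__, e)
finally:
  # Turn the flag back off so later code runs without the extra NaN checks.
  jax.config.update("jax_debug_nans", False)
```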
Read more in JAX debugging flags.
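Besides the global jax_disable_jit flag (set with jax.config.update("jax_disable_jit", True)), JAX provides the jax.disable_jit() context manager, which turns off compilation for a limited block. The sketch below reuses a jitted function similar to the earlier print example. Inside the block, the function runs eagerly, so Python's print() receives a concrete array instead of a tracer.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  # Under jit this prints a tracer; with JIT disabled, x is a concrete array.
  print("x ->", x)
  return jnp.sin(x)

with jax.disable_jit():
  y = f(jnp.array(2.0))  # runs eagerly, so print shows the actual value
```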
Next steps
Check out Advanced debugging to learn more about debugging in JAX.
