JAX debugging flags
JAX offers flags and context managers that enable catching errors more easily.
jax_debug_nans configuration option and context manager#
Summary: Enable the jax_debug_nans flag to automatically detect when NaNs are produced in jax.jit-compiled code.
jax_debug_nans is a JAX flag that, when enabled, causes computations to raise an error immediately when a NaN is produced. Switching this option on adds a NaN check to every floating-point value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under a @jax.jit.
For code under a @jax.jit, the output of every @jax.jit function is checked, and if a NaN is present the function is re-run in de-optimized op-by-op mode, effectively removing one level of @jax.jit at a time.
Tricky situations can arise, such as NaNs that only occur under a @jax.jit but are not produced in de-optimized mode. In that case a warning message is printed but your code continues to execute.
If the NaNs are produced in the backward pass of a gradient evaluation, then when the exception is raised you will find yourself, several frames up in the stack trace, in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse.
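As a rough illustration of a NaN that appears only in the backward pass, consider the sketch below. The function my_log is a hypothetical example, not part of the original text: its forward value is finite everywhere, but differentiating the masked-out log branch at 0 yields a NaN gradient. The sketch uses the jax.debug_nans context manager to compare behavior with the checker off and on.

```python
import jax
import jax.numpy as jnp

def my_log(x):
    # Hypothetical example: non-positive inputs are masked to 0, so the
    # forward value never contains a NaN.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)

# Checker off: the gradient at 0 silently evaluates to NaN, because the
# backward pass still differentiates the masked-out log branch.
with jax.debug_nans(False):
    silent_grad = jax.grad(my_log)(0.0)

# Checker on: the same gradient computation raises instead.
caught = None
with jax.debug_nans(True):
    try:
        jax.grad(my_log)(0.0)
    except FloatingPointError as e:
        caught = e

print("gradient without checker:", silent_grad)
print("error raised with checker:", caught is not None)
```

Inspecting the traceback of the caught exception is one way to see which frames of the gradient evaluation were active when the NaN appeared.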
Usage#
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by doing one of:
running your code inside the jax.debug_nans context manager, using with jax.debug_nans(True):;
setting the JAX_DEBUG_NANS=True environment variable;
adding jax.config.update("jax_debug_nans", True) near the top of your main file;
adding jax.config.parse_flags_with_absl() to your main file, then setting the option using a command-line flag like --jax_debug_nans=True.
Example(s)#
    import jax
    import jax.numpy as jnp
    import traceback

    jax.config.update("jax_debug_nans", True)

    def f(x):
      w = 3 * jnp.square(x)
      return jnp.log(-w)

    # The stack trace is very long so only print a couple lines.
    try:
      f(5.)
    except FloatingPointError as e:
      print(traceback.format_exc(limit=2))
The NaN generated was caught. By running %debug, we can get a post-mortem debugger. This also works with functions under @jax.jit, as the example below shows.
    jax.jit(f)(5.)  # raises an exception
When this code sees a NaN in the output of a @jax.jit function, it calls into the de-optimized code, so we still get a clear stack trace. And we can run a post-mortem debugger with %debug to inspect all the values to figure out the error.
The jax.debug_nans context manager can be used to activate/deactivate NaN debugging. Since we activated it above with jax.config.update, let’s deactivate it:
    with jax.debug_nans(False):
      print(jax.jit(f)(5.))
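The context manager can also be used the other way around, to switch the checker on for a single region and fall back to the previous setting afterwards. The following is a minimal sketch of that scoping, with the outer setting made explicit so the behavior does not depend on any global configuration.

```python
import jax
import jax.numpy as jnp

def f(x):
    w = 3 * jnp.square(x)
    return jnp.log(-w)  # log of a negative number is NaN

with jax.debug_nans(False):        # explicit baseline: checker off
    raised_inside = False
    with jax.debug_nans(True):     # checker on for this region only
        try:
            f(5.0)
        except FloatingPointError:
            raised_inside = True
    # The inner block has exited, so the previous setting applies again
    # and the NaN is returned rather than raised.
    after = f(5.0)

print("raised inside checked region:", raised_inside)
print("value after leaving it:", after)
```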
Strengths and limitations of jax_debug_nans#
Strengths#
Easy to apply
Precisely detects where NaNs were produced
Throws a standard Python exception and is compatible with PDB postmortem
Limitations#
Re-running functions eagerly can be slow. You shouldn’t have the NaN-checker on if you’re not debugging, as it can introduce lots of device-host round-trips and performance regressions.
Raises errors on false positives (e.g. intentionally created NaNs)
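To make the false-positive limitation concrete, here is a small sketch in which NaN is used deliberately as a marker for invalid entries. The mask_invalid helper and its inputs are hypothetical. One possible workaround, shown here, is to wrap only the intentional step in jax.debug_nans(False) while leaving the checker on for the surrounding code.

```python
import jax
import jax.numpy as jnp

def mask_invalid(x, valid):
    # Hypothetical helper: deliberately marks invalid entries with NaN.
    return jnp.where(valid, x, jnp.nan)

x = jnp.array([1.0, 2.0, 3.0])
valid = jnp.array([True, False, True])

with jax.debug_nans(True):
    # The intentional NaN is reported just like an accidental one.
    false_positive = False
    try:
        mask_invalid(x, valid)
    except FloatingPointError:
        false_positive = True

    # Switch the checker off only around the intentional step.
    with jax.debug_nans(False):
        masked = mask_invalid(x, valid)

    # Back under the checker: nansum skips the NaN marker, so its output
    # is finite and nothing is raised.
    total = jnp.nansum(masked)

print("false positive raised:", false_positive)
print("sum of valid entries:", total)
```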
jax_debug_infs configuration option and context manager#
jax_debug_infs works similarly to jax_debug_nans. jax_debug_infs often needs to be combined with jax_disable_jit, since Infs might not cascade to the output like NaNs. Alternatively, jax.experimental.checkify may be used to find Infs in intermediates.
Full documentation of jax_debug_infs is forthcoming.
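The sketch below illustrates why the combination with jax_disable_jit can matter. The function g is a hypothetical example in which an intermediate value overflows to Inf while the final output stays finite. It assumes the jax.debug_infs and jax.disable_jit context managers as the scoped counterparts of the two flags.

```python
import jax
import jax.numpy as jnp

def g(x):
    big = jnp.exp(x)          # overflows to Inf for large x
    return 1.0 / (1.0 + big)  # 1 / Inf is 0, so the output is finite

# Under jit only the function's output is checked, and it is finite,
# so the intermediate Inf goes unnoticed.
with jax.debug_infs(True):
    jitted_out = jax.jit(g)(1000.0)

# With JIT disabled, every primitive result is checked, so the Inf
# produced by exp is reported.
caught = None
with jax.debug_infs(True), jax.disable_jit():
    try:
        jax.jit(g)(1000.0)
    except FloatingPointError as e:
        caught = e

print("jitted output:", jitted_out)
print("Inf reported with JIT disabled:", caught is not None)
```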
jax_disable_jit configuration option and context manager#
Summary: Enable the jax_disable_jit flag to disable JIT-compilation, enabling use of traditional Python debugging tools like print and pdb.
jax_disable_jit is a JAX flag that, when enabled, disables JIT-compilation throughout JAX (including in control flow functions like jax.lax.cond and jax.lax.scan).
Usage#
You can disable JIT-compilation by:
setting the JAX_DISABLE_JIT=True environment variable;
adding jax.config.update("jax_disable_jit", True) near the top of your main file;
adding jax.config.parse_flags_with_absl() to your main file, then setting the option using a command-line flag like --jax_disable_jit=True.
Examples#
    import jax
    import jax.numpy as jnp

    jax.config.update("jax_disable_jit", True)

    def f(x):
      y = jnp.log(x)
      if jnp.isnan(y):
        breakpoint()
      return y

    jax.jit(f)(-2.)  # ==> Enters PDB breakpoint!
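For a scoped alternative to the global flag, the following sketch assumes the jax.disable_jit context manager that the section title refers to. It records whether the function body receives an abstract tracer or a concrete value, which is what determines whether tools like print and pdb show real numbers. The seen list is a hypothetical helper for illustration.

```python
import jax
import jax.numpy as jnp

seen = []  # hypothetical helper: records whether f received a tracer

@jax.jit
def f(x):
    seen.append(isinstance(x, jax.core.Tracer))
    return jnp.log(x)

f(2.0)                   # compiled path: x is an abstract tracer
with jax.disable_jit():
    f(2.0)               # plain Python path: x is a concrete value

print(seen)
```

Because the context manager only affects the code inside the with block, the rest of the program keeps its compiled performance.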
Strengths and limitations of jax_disable_jit#
Strengths#
Easy to apply
Enables use of Python’s built-in breakpoint and print
Throws standard Python exceptions and is compatible with PDB postmortem
Limitations#
Running functions without JIT-compilation can be slow
