JAX documentation

Compiled prints and breakpoints

The jax.debug package offers some useful tools for inspecting values inside of compiled functions.

Debugging with jax.debug.print and other debugging callbacks

Summary: Use jax.debug.print() to print traced array values to stdout in compiled (e.g. jax.jit or jax.pmap-decorated) functions:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯
```

With some transformations, like jax.grad and jax.vmap, you can use Python’s builtin print function to print out numerical values. But print won’t work with jax.jit or jax.pmap because those transformations delay numerical evaluation. So use jax.debug.print instead!

Semantically, jax.debug.print is roughly equivalent to the following Python function

```python
# Roughly what jax.debug.print does, ignoring staging and transformations.
# args and kwargs are pytrees of arrays.
def debug_print(fmt: str, *args: "PyTree[Array]", **kwargs: "PyTree[Array]") -> None:
  print(fmt.format(*args, **kwargs))
```

except that it can be staged out and transformed by JAX. See the API reference for more details.

Note that fmt cannot be an f-string because f-strings are formatted immediately, whereas for jax.debug.print, we’d like to delay formatting until later.
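The following small sketch shows the difference. It assumes a jitted function that is called twice with Python floats. The list trace_time_messages is a hypothetical helper used only to record when the f-string is evaluated. That happens once, during tracing, and the f-string captures the tracer’s description rather than the value 2.0. The jax.debug.print format string is filled in with the actual value on every call.

```python
import jax
import jax.numpy as jnp

trace_time_messages = []  # hypothetical helper: records f-string results

@jax.jit
def f(x):
  # An f-string is formatted immediately, i.e. once while tracing,
  # so it sees an abstract tracer instead of the runtime value.
  trace_time_messages.append(f"x = {x}")
  # A plain format string is filled in later, with the runtime value.
  jax.debug.print("x = {}", x)
  return x

f(2.)
f(3.)  # same shape and dtype: the cached trace is reused, no new f-string
```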

When to use “debug” print?

You should use jax.debug.print for dynamic (i.e. traced) array values within JAX transformations like jit, vmap, and others. For printing of static values (like array shapes or dtypes), you can use a normal Python print statement.
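In the sketch below, the function f is only illustrative. The plain print reports the static shape and dtype once, at trace time. The jax.debug.print call reports the traced values each time the compiled function runs.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  # Shapes and dtypes are static: a plain print runs once, while tracing.
  print("shape:", x.shape, "dtype:", x.dtype)
  # The array's values are only known at run time: use jax.debug.print.
  jax.debug.print("values: {}", x)
  return x * 2

out = f(jnp.arange(3.))
# Prints during tracing: shape: (3,) dtype: float32
```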

Why “debug” print?

In the name of debugging, jax.debug.print can reveal information about how computations are evaluated:

```python
xs = jnp.arange(3.)

def f(x):
  jax.debug.print("x: {}", x)
  y = jnp.sin(x)
  jax.debug.print("y: {}", y)
  return y

jax.vmap(f)(xs)
# Prints: x: 0.0
#         x: 1.0
#         x: 2.0
#         y: 0.0
#         y: 0.841471
#         y: 0.9092974

jax.lax.map(f, xs)
# Prints: x: 0.0
#         y: 0.0
#         x: 1.0
#         y: 0.841471
#         x: 2.0
#         y: 0.9092974
```

Notice that the printed results are in different orders!

By revealing these inner workings, the output of jax.debug.print doesn’t respect JAX’s usual semantics guarantees, such as the guarantee that jax.vmap(f)(xs) and jax.lax.map(f, xs) compute the same thing (in different ways). Yet these evaluation order details are exactly what we might want to see when debugging!

So use jax.debug.print for debugging, and not when semantics guarantees are important.

More examples of jax.debug.print

In addition to the above examples using jit and vmap, here are a few more to keep in mind.

Printing under jax.pmap

Under jax.pmap, jax.debug.print calls might be reordered!

```python
xs = jnp.arange(2.)

def f(x):
  jax.debug.print("x: {}", x)
  return x

# pmap maps over devices, so this needs at least 2 devices.
jax.pmap(f)(xs)
# Prints: x: 0.0
#         x: 1.0
# OR
# Prints: x: 1.0
#         x: 0.0
```

Printing under jax.grad

Under a jax.grad, jax.debug.prints will only print on the forward pass:

```python
def f(x):
  jax.debug.print("x: {}", x)
  return x * 2.

jax.grad(f)(1.)
# Prints: x: 1.0
```

This behavior is similar to how Python’s builtin print works under a jax.grad. But by using jax.debug.print here, the behavior is the same even if the caller applies a jax.jit.
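The following sketch makes this concrete by applying jax.jit on top of jax.grad. The function f is illustrative. The builtin print fires only while tracing and shows a tracer. The jax.debug.print call still reports the runtime value on the forward pass.

```python
import jax

def f(x):
  # Builtin print runs while tracing; under jit it shows a tracer, not 1.0.
  print("builtin print:", x)
  # jax.debug.print is staged out, so it reports the runtime value even under jit.
  jax.debug.print("debug print: {}", x)
  return x * 2.

dfdx = jax.jit(jax.grad(f))(1.)
# jax.debug.print prints: debug print: 1.0
```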

To print on the backward pass, just use a jax.custom_vjp:

```python
@jax.custom_vjp
def print_grad(x):
  return x

def print_grad_fwd(x):
  # Forward pass: return the output and (no) residuals.
  return x, None

def print_grad_bwd(_, x_grad):
  # Backward pass: print the incoming cotangent, then pass it through unchanged.
  jax.debug.print("x_grad: {}", x_grad)
  return (x_grad,)

print_grad.defvjp(print_grad_fwd, print_grad_bwd)

def f(x):
  x = print_grad(x)
  return x * 2.

jax.grad(f)(1.)
# Prints: x_grad: 2.0
```

Printing in other transformations

jax.debug.print also works in other transformations like pjit.

More control with jax.debug.callback

In fact, jax.debug.print is a thin convenience wrapper around jax.debug.callback, which can be used directly for greater control over string formatting, or even the kind of output.

Semantically, jax.debug.callback is roughly equivalent to the following Python function

```python
from typing import Callable

# Roughly what jax.debug.callback does, ignoring staging and transformations.
def callback(fun: Callable, *args: "PyTree[Array]", **kwargs: "PyTree[Array]") -> None:
  fun(*args, **kwargs)
  return None
```
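The sketch below calls jax.debug.callback directly. Instead of printing the raw array, it computes a summary on the host. The names log_stats and history are hypothetical and were chosen for this example. The callback receives concrete array values rather than tracers. The code calls jax.effects_barrier() (see Asynchronous callbacks) so that the host-side list is complete before it is read.

```python
import jax
import jax.numpy as jnp
import numpy as np

history = []  # hypothetical host-side log, for illustration only

def log_stats(x):
  # Called on the host with a concrete array, not a tracer.
  x = np.asarray(x)
  history.append((float(x.mean()), float(x.max())))
  print(f"mean={x.mean():.3f} max={x.max():.3f}")

@jax.jit
def f(x):
  jax.debug.callback(log_stats, x)
  return jnp.tanh(x)

f(jnp.arange(4.)).block_until_ready()
jax.effects_barrier()  # wait for the callback before reading `history`
```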

As with jax.debug.print, these callbacks should only be used for debugging output, like printing or plotting. Printing and plotting are pretty harmless, but if you use callbacks for anything else, their behavior might surprise you under transformations. For example, it’s not safe to use jax.debug.callback for timing operations, since callbacks might be reordered and asynchronous (see below).

Sharp bits

Like most JAX APIs, jax.debug.print can cut you if you’re not careful.

Ordering of printed results

When distinct calls to jax.debug.print involve arguments which don’t depend on one another, they might be reordered when staged out, e.g. by jax.jit:

```python
@jax.jit
def f(x, y):
  jax.debug.print("x: {}", x)
  jax.debug.print("y: {}", y)
  return x + y

f(2., 3.)
# Prints: x: 2.0
#         y: 3.0
# OR
# Prints: y: 3.0
#         x: 2.0
```

Why? Under the hood, the compiler gets a functional representation of the staged-out computation, where the imperative order of the Python function is lost and only data dependence remains. This change is invisible to users with functionally pure code, but in the presence of side-effects like printing, it’s noticeable.

To preserve the original order of jax.debug.prints as written in your Python function, you can use jax.debug.print(..., ordered=True), which will ensure the relative order of prints is preserved. But using ordered=True will raise an error under jax.pmap and other JAX transformations involving parallelism, since ordering can’t be guaranteed under parallel execution.
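Here is the earlier two-print example rewritten with ordered=True. It is a sketch for single-device execution. With ordered=True, the two prints should appear in the order they are written.

```python
import jax

@jax.jit
def f(x, y):
  # ordered=True preserves the relative order of these two prints.
  jax.debug.print("x: {}", x, ordered=True)
  jax.debug.print("y: {}", y, ordered=True)
  return x + y

out = f(2., 3.)
# Prints: x: 2.0
#         y: 3.0
```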

Computation perturbation

Adding jax.debug.print or jax.debug.breakpoint statements will change the computation that XLA is asked to compile. This can potentially result in numeric discrepancies compared to the same code without debug statements because XLA might perform different operation fusions during compilation. Keep this in mind when debugging numerical issues, as the act of adding debug statements might affect the behavior you’re trying to investigate.

Asynchronous callbacks

Depending on the backend, jax.debug.prints may happen asynchronously, i.e. not in your main program thread. This means that values could be printed to your screen even after your JAX function has returned a value.

```python
@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x

f(2.).block_until_ready()
# <do something else>
# Prints: x: 2.0
```

To block on the jax.debug.prints in a function, you can call jax.effects_barrier(), which will wait until any remaining side-effects in the function have completed as well:

```python
@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x

f(2.).block_until_ready()
jax.effects_barrier()
# Prints: x: 2.0
# <do something else>
```

Performance impacts

Unnecessary materialization

While jax.debug.print was designed to have a minimal performance footprint, it can interfere with compiler optimizations and potentially affect the memory profile of your JAX programs.

```python
def f(w, b, x):
  logits = w.dot(x) + b
  jax.debug.print("logits: {}", logits)
  return jax.nn.relu(logits)
```

In this example, we are printing intermediate values in between a linear layer and the activation function. Compilers like XLA can perform fusion optimizations, which might avoid materializing logits in memory. But when we use jax.debug.print on logits, we are forcing those intermediates to be materialized, potentially slowing down the program and increasing memory usage.

Furthermore, when using jax.debug.print with jax.pjit, a global synchronization occurs that will materialize values on a single device.

Callback overhead

jax.debug.print inherently incurs communication between an accelerator and its host. The underlying mechanism differs from backend to backend (e.g. GPU vs TPU), but in all cases, we’ll need to copy the printed values from device to host. In the CPU case, this overhead is smaller.

Furthermore, when using jax.debug.print with jax.pjit, a global synchronization occurs that adds some overhead.

Strengths and limitations of jax.debug.print

Strengths

  • Print debugging is simple and intuitive

  • jax.debug.callback can be used for other innocuous side-effects

Limitations

  • Adding print statements is a manual process

  • Can have performance impacts

Interactive inspection with jax.debug.breakpoint()

Summary: Use jax.debug.breakpoint() to pause the execution of your JAX program to inspect values:

```python
@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z

f(2.)  # ==> Pauses during execution!
```

(Figure: JAX debugger)

jax.debug.breakpoint() is actually just an application of jax.debug.callback(...) that captures information about the call stack. As a result, it has the same transformation behaviors as jax.debug.print (e.g. vmap-ing jax.debug.breakpoint() unrolls it across the mapped axis).

Usage

Calling jax.debug.breakpoint() in a compiled JAX function will pause your program when it hits the breakpoint. You’ll be presented with a pdb-like prompt that allows you to inspect the values in the call stack. Unlike pdb, you will not be able to step through the execution, but you are allowed to resume it.

Debugger commands:

  • help - prints out available commands

  • p - evaluates an expression and prints its result

  • pp - evaluates an expression and pretty-prints its result

  • u(p) - go up a stack frame

  • d(own) - go down a stack frame

  • w(here)/bt - print out a backtrace

  • l(ist) - print out code context

  • c(ont(inue)) - resumes the execution of the program

  • q(uit)/exit - exits the program (does not work on TPU)

Examples

Usage with jax.lax.cond

When combined with jax.lax.cond, the debugger can become a useful tool for detecting NaNs or infs.

```python
import jax
import jax.numpy as jnp
from jax import lax

def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass  # all values are finite: do nothing
  def false_fn(x):
    jax.debug.breakpoint()  # some value is NaN or inf: pause here
  lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z

f(2., 0.)  # ==> Pauses during execution!
```
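In a non-interactive job, you may prefer to record the offending values rather than stop at a prompt. The same lax.cond pattern can call jax.debug.callback instead of jax.debug.breakpoint. The sketch below assumes that goal. The names report_nonfinite, callback_if_nonfinite, and flagged are hypothetical and not part of JAX.

```python
import jax
import jax.numpy as jnp
from jax import lax

flagged = []  # hypothetical host-side record of non-finite results

def report_nonfinite(x):
  # Runs on the host with the concrete value that failed the check.
  flagged.append(x)
  print("non-finite value detected:", x)

def callback_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  lax.cond(
      is_finite,
      lambda x: None,  # all values finite: do nothing
      lambda x: jax.debug.callback(report_nonfinite, x),  # otherwise report
      x,
  )

@jax.jit
def f(x, y):
  z = x / y
  callback_if_nonfinite(z)
  return z

f(2., 1.)  # finite result: no report
f(2., 0.)  # inf: report_nonfinite is called
jax.effects_barrier()  # make sure callbacks have run before reading `flagged`
```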

Sharp bits

Because jax.debug.breakpoint is just an application of jax.debug.callback, it has the same sharp bits as jax.debug.print, with a few more caveats:

  • jax.debug.breakpoint materializes even more intermediates than jax.debug.print because it forces materialization of all values in the call stack

  • jax.debug.breakpoint has more runtime overhead than a jax.debug.print because it potentially has to copy all the intermediate values in a JAX program from device to host.

Strengths and limitations of jax.debug.breakpoint()

Strengths

  • Simple, intuitive and (somewhat) standard

  • Can inspect many values at the same time, up and down the call stack

Limitations

  • Need to potentially use many breakpoints to pinpoint the source of an error

  • Materializes many intermediates

