
External callbacks#

This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are jax.pure_callback, jax.experimental.io_callback, and jax.debug.callback. You can use them even while running under JAX transformations, including jit(), vmap(), and grad().

Why callbacks?#

A callback routine is a way to perform host-side execution of code at runtime. As a simple example, suppose you’d like to print the value of some variable during the course of a computation. Using a simple Python print() statement, it looks like this:

```python
import jax

@jax.jit
def f(x):
  y = x + 1
  print("intermediate value: {}".format(y))
  return y * 2

result = f(2)
```
intermediate value: JitTracer<~int32[]>

What is printed is not the runtime value, but the trace-time abstract value (if you’re not familiar with tracing in JAX, a good primer can be found in Tracing).

To print the value at runtime, you need a callback, for example jax.debug.print() (you can learn more about debugging in Introduction to debugging):

```python
@jax.jit
def f(x):
  y = x + 1
  jax.debug.print("intermediate value: {}", y)
  return y * 2

result = f(2)
```
intermediate value: 3

This works by passing the runtime value of y as a CPU jax.Array back to the host process, where the host can print it.
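To see that the host really does receive a concrete value rather than a tracer, you can pass y to jax.debug.callback() with your own host-side function. This is a small illustrative sketch, not part of the original example:

```python
import jax

@jax.jit
def f(x):
  y = x + 1
  # At runtime the host-side function receives a concrete value, not a tracer:
  jax.debug.callback(lambda val: print("host received:", val), y)
  return y * 2

f(2)  # the host prints: host received: 3
```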

Flavors of callback#

In earlier versions of JAX, there was only one kind of callback available, implemented in jax.experimental.host_callback(). The host_callback routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:

- jax.pure_callback(): appropriate for pure functions, i.e. functions with no side-effects.
- jax.experimental.io_callback(): appropriate for impure functions, e.g. functions that read or write data to disk.
- jax.debug.callback(): appropriate for functions whose execution should reflect what the compiler is actually doing.

(The jax.debug.print() function you used previously is a wrapper around jax.debug.callback().)

From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.

| callback function | supports return value | jit | vmap | grad | scan/while_loop | guaranteed execution |
|---|---|---|---|---|---|---|
| jax.pure_callback() | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ |
| jax.experimental.io_callback() | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ |
| jax.debug.callback() | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ |

¹ jax.pure_callback can be used with custom_jvp to make it compatible with autodiff.

² jax.experimental.io_callback is compatible with vmap only if ordered=False.

³ Note that vmap of scan/while_loop of io_callback has complicated semantics, and its behavior may change in future releases.

Exploring pure_callback#

jax.pure_callback() is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function with no side-effects such as printing values, reading data from disk, or updating a global state.

The function you pass to jax.pure_callback() need not actually be pure, but it will be assumed pure by JAX’s transformations and higher-order functions, which means that it may be silently elided or called multiple times.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f_host(x):
  # call a numpy (not jax.numpy) operation:
  return np.sin(x).astype(x.dtype)

def f(x):
  result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(f_host, result_shape, x, vmap_method='sequential')

x = jnp.arange(5.0)
f(x)
```
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

Because pure_callback can be elided or duplicated, it is compatible out-of-the-box with transformations like jit as well as higher-order primitives like scan and while_loop:

```python
jax.jit(f)(x)
```
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)
```python
def body_fun(_, x):
  return _, f(x)

jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
```
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

Because we specified a vmap_method in the pure_callback function call, it will also be compatible with vmap:

```python
jax.vmap(f)(x)
```
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

However, because there is no way for JAX to introspect the content of the callback, pure_callback has undefined autodiff semantics:

```python
jax.grad(f)(x)
```
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

For an example of using pure_callback with jax.custom_jvp(), see Example: pure_callback with custom_jvp below.

By design, functions passed to pure_callback are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:

```python
def print_something():
  print('printing something')
  return np.int32(0)

@jax.jit
def f1():
  return jax.pure_callback(print_something, np.int32(0))

f1();
```
printing something
```python
@jax.jit
def f2():
  jax.pure_callback(print_something, np.int32(0))
  return 1.0

f2();
```

In f1, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output. In f2, on the other hand, the output of the callback is unused, so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects.

pure_callback and exceptions#

In the context of JAX transformations, Python runtime exceptions should be considered side-effects: this means that intentionally raising an error within a pure_callback breaks the API contract, and the behavior of the resulting program is undefined. In particular, the manner in which such a program halts will generally depend on the backend, and the details of that behavior may change in future releases.

Additionally, passing impure functions to pure_callback may result in unexpected behavior during transformations like jax.jit() or jax.vmap(), because the transformation rules for pure_callback are defined under the assumption that the callback function is pure. Here’s one simple example of an impure callback behaving unexpectedly under vmap:

```python
import jax
import jax.numpy as jnp

def raise_via_callback(x):
  def _raise(x):
    raise ValueError(f"value of x is {x}")
  return jax.pure_callback(_raise, x, x)

def raise_if_negative(x):
  return jax.lax.cond(x < 0, raise_via_callback, lambda x: x, x)

x_batch = jnp.arange(4)

[raise_if_negative(x) for x in x_batch]  # does not raise

jax.vmap(raise_if_negative)(x_batch)  # ValueError: value of x is 0
```

To avoid this and similar unexpected behavior, we recommend not attempting to use pure_callback to raise runtime errors.

Exploring io_callback#

In contrast to jax.pure_callback(), jax.experimental.io_callback() is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.

As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated. (Please note that this is meant as a toy example of io_callback and not necessarily a recommended way of generating random numbers in JAX!)

```python
from jax.experimental import io_callback
from functools import partial

global_rng = np.random.default_rng(0)

def host_side_random_like(x):
  """Generate a random array like x using the global_rng state"""
  # We have two side-effects here:
  # - printing the shape and dtype
  # - calling global_rng, thus updating its state
  print(f'generating {x.dtype}{list(x.shape)}')
  return global_rng.uniform(size=x.shape).astype(x.dtype)

@jax.jit
def numpy_random_like(x):
  return io_callback(host_side_random_like, x, x)

x = jnp.zeros(5)
numpy_random_like(x)
```
generating float32[5]
Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ],      dtype=float32)

The io_callback is compatible with vmap by default:

```python
jax.vmap(numpy_random_like)(x)
```
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.91275555, 0.60663575, 0.72949654, 0.543625  , 0.9350724 ],      dtype=float32)

Note, however, that this may execute the mapped callbacks in any order. So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run.

If it is important that the order of callbacks be preserved, you can set ordered=True, in which case attempting to vmap will raise an error:

```python
@jax.jit
def numpy_random_like_ordered(x):
  return io_callback(host_side_random_like, x, x, ordered=True)

jax.vmap(numpy_random_like_ordered)(x)
```
ValueError: Cannot `vmap` ordered IO callback.

On the other hand, scan and while_loop work with io_callback regardless of whether ordering is enforced:

```python
def body_fun(_, x):
  return _, numpy_random_like_ordered(x)

jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
```
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544],      dtype=float32)

Like pure_callback, io_callback fails under automatic differentiation if it is passed a differentiated variable:

```python
jax.grad(numpy_random_like)(x)
```
ValueError: IO callbacks do not support JVP.

However, if the callback is not dependent on a differentiated variable, it will execute:

```python
@jax.jit
def f(x):
  io_callback(lambda: print('hello'), None)
  return x

jax.grad(f)(1.0);
```
hello

Unlike pure_callback, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation.

Exploring debug.callback#

Both pure_callback and io_callback enforce some assumptions about the purity of the function they’re calling, and limit in various ways what JAX transforms and compilation machinery may do. debug.callback essentially assumes nothing about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, debug.callback cannot return any value to the program.

```python
from jax import debug

def log_value(x):
  # This could be an actual logging call; we'll use
  # print() for demonstration
  print("log:", x)

@jax.jit
def f(x):
  debug.callback(log_value, x)
  return x

f(1.0);
```
log: 1.0

The debug callback is compatible with vmap:

```python
x = jnp.arange(5.0)
jax.vmap(f)(x);
```
log: 0.0
log: 1.0
log: 2.0
log: 3.0
log: 4.0

And is also compatible with grad and other autodiff transformations:

```python
jax.grad(f)(1.0);
```
log: 1.0

This can make debug.callback more useful for general-purpose debugging than either pure_callback or io_callback.

Example: pure_callback with custom_jvp#

One powerful way to take advantage of jax.pure_callback() is to combine it with jax.custom_jvp(). (Refer to Custom derivative rules for JAX-transformable Python functions for more details on jax.custom_jvp().)

Suppose you want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the jax.scipy or jax.numpy wrappers.

Here, we’ll consider creating a wrapper for the Bessel function of the first kind, available in scipy.special.jv. You can start by defining a straightforward pure_callback():

```python
import jax
import jax.numpy as jnp
import scipy.special

def jv(v, z):
  v, z = jnp.asarray(v), jnp.asarray(z)

  # Require the order v to be integer type: this simplifies
  # the JVP rule below.
  assert jnp.issubdtype(v.dtype, jnp.integer)

  # Promote the input to inexact (float/complex).
  # Note that jnp.result_type() accounts for the enable_x64 flag.
  z = z.astype(jnp.result_type(float, z.dtype))

  # Wrap scipy function to return the expected dtype.
  _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)

  # Define the expected shape & dtype of output.
  result_shape_dtype = jax.ShapeDtypeStruct(
      shape=jnp.broadcast_shapes(v.shape, z.shape),
      dtype=z.dtype)

  # Use vmap_method="broadcast_all" because scipy.special.jv handles broadcasted inputs.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vmap_method="broadcast_all")
```

This lets us call into scipy.special.jv() from transformed JAX code, including when transformed by jit() and vmap():

```python
from functools import partial

j1 = partial(jv, 1)
z = jnp.arange(5.0)

print(j1(z))
```
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

Here is the same result withjit():

```python
print(jax.jit(j1)(z))
```
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

And here is the same result again withvmap():

```python
print(jax.vmap(j1)(z))
```
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

However, if you callgrad(), you will get an error because there is no autodiff rule defined for this function:

```python
jax.grad(j1)(z)
```
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

Let’s define a custom gradient rule for this. Looking at the definition of the Bessel function of the first kind, you find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument z:

\[
\frac{d}{dz} J_\nu(z) = \begin{cases}
-J_1(z), & \nu = 0 \\
\left[J_{\nu - 1}(z) - J_{\nu + 1}(z)\right]/2, & \nu \ne 0
\end{cases}
\]

The gradient with respect to \(\nu\) is more complicated, but since we’ve restricted the v argument to integer types, you don’t need to worry about its gradient for the sake of this example.
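Before wiring this rule into JAX, you can sanity-check the \(\nu \ne 0\) branch of the recurrence numerically using scipy alone (a quick illustrative check, not part of the original tutorial):

```python
import numpy as np
import scipy.special

v, z, eps = 1, 2.0, 1e-6
# Central finite-difference approximation of dJ_v/dz at z:
finite_diff = (scipy.special.jv(v, z + eps) - scipy.special.jv(v, z - eps)) / (2 * eps)
# Recurrence-based derivative for v != 0:
recurrence = 0.5 * (scipy.special.jv(v - 1, z) - scipy.special.jv(v + 1, z))
print(np.isclose(finite_diff, recurrence, atol=1e-5))  # True
```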

You can usejax.custom_jvp() to define this automatic differentiation rule for your callback function:

```python
jv = jax.custom_jvp(jv)

@jv.defjvp
def _jv_jvp(primals, tangents):
  v, z = primals
  _, z_dot = tangents  # Note: v_dot is always 0 because v is integer.
  jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
  djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
  return jv(v, z), z_dot * djv_dz
```

Now computing the gradient of your function will work correctly:

```python
j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))
```
-0.06447162

Further, since we’ve defined the gradient in terms of jv itself, JAX’s architecture means that you get second-order and higher derivatives for free:

```python
jax.hessian(j1)(2.0)
```
Array(-0.4003078, dtype=float32, weak_type=True)

Keep in mind that although this all works correctly with JAX, each call to your callback-based jv function will result in passing the input data from the device to the host, and passing the output of scipy.special.jv() from the host back to the device.

When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time jv is called.

However, if you are running JAX on a single CPU (where the “host” and “device” are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern a relatively straightforward way to extend JAX’s capabilities.

