jax.debug.callback
- jax.debug.callback(callback, *args, ordered=False, partitioned=False, **kwargs)
Calls a stageable Python callback.
For more explanation, see External Callbacks.
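Here is a minimal sketch of basic usage. The names `record`, `seen`, and `f` are illustrative and not part of the API. The callback runs on the host with concrete array values, even though it is called from inside a `jax.jit`-compiled function. `jax.effects_barrier()` waits for any pending callbacks before the recorded values are inspected.

```python
import jax
import jax.numpy as jnp
import numpy as np

seen = []  # host-side list filled in by the callback (illustrative name)


def record(x):
    # Runs on the host with a concrete array value, even inside jit.
    seen.append(np.asarray(x))


@jax.jit
def f(x):
    y = jnp.sin(x)
    jax.debug.callback(record, y)  # returns None; does not change f's outputs
    return y * 2


out = f(jnp.arange(3.0))
jax.effects_barrier()  # wait until pending callbacks have run
print(seen)
```

The callback's return value is ignored, so it can only observe values. It cannot feed results back into the computation.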
jax.debug.callback enables you to pass in a Python function that can be called inside of a staged JAX program. A jax.debug.callback follows the pure operational semantics of existing JAX transformations, which are therefore unaware of side effects. This means the callback's effect could be dropped, duplicated, or potentially reordered in the presence of higher-order primitives and transformations.

We want this behavior because we'd like jax.debug.callback to be “innocuous”, i.e. we want these primitives to change the JAX computation as little as possible while revealing as much about the computation as possible, such as which parts of it are duplicated or dropped.
- Parameters:
callback (Callable[..., None]) – A Python callable returning None.
*args (Any) – The positional arguments to the callback.
ordered (bool) – A keyword-only argument indicating whether the staged-out computation enforces the ordering of this callback relative to other ordered callbacks.
partitioned (bool) – If True, the callback receives only the local shards of its operands; this option avoids an all-gather of the operands. If False, the callback receives the full logical operands; this option requires an all-gather of the operands first.
**kwargs (Any) – The keyword arguments to the callback.
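The sketch below exercises `ordered=True` and the forwarding of `**kwargs` to the callback. It assumes a single-device CPU setup, and `log_value`, `log`, and `g` are hypothetical names. The non-array tag is bound with `functools.partial`, so only array values travel through `**kwargs`. This avoids depending on whether a given JAX version accepts non-array operands.

```python
import functools

import jax

log = []  # (tag, value) pairs in the order the callbacks ran (hypothetical)


def log_value(tag, *, value):
    # `value` arrives through jax.debug.callback's **kwargs as a concrete array.
    log.append((tag, float(value)))


@jax.jit
def g(x):
    # ordered=True asks the staged computation to keep these two callbacks
    # in program order relative to each other.
    jax.debug.callback(functools.partial(log_value, "first"), value=x, ordered=True)
    jax.debug.callback(functools.partial(log_value, "second"), value=x + 1, ordered=True)
    return x * 2


out = g(1.0)
jax.effects_barrier()  # make sure both callbacks have finished
print(log)
```

Without `ordered=True`, the two callbacks may still happen to run in program order, but that order is not enforced.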
- Returns:
None
- Return type:
None
See also
jax.experimental.io_callback(): callback designed for impure functions.
jax.pure_callback(): callback designed for pure functions.
jax.debug.print(): callback designed for printing.
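The transformation-dependent behavior of jax.debug.callback (effects dropped, duplicated, or reordered) can be observed under `jax.vmap`. The sketch below records the shape of every value the callback receives; `record_shape`, `calls`, and `h` are hypothetical names. At the time of writing, the batching rule for jax.debug.callback unrolls the callback across the mapped axis, so it typically runs once per batch element with unbatched values. Treat the exact number of calls as an implementation detail, not a guarantee.

```python
import jax
import jax.numpy as jnp
import numpy as np

calls = []  # shape of each value the callback actually received (hypothetical)


def record_shape(x):
    calls.append(np.shape(x))


def h(x):
    jax.debug.callback(record_shape, x)
    return x + 1


batched = jax.jit(jax.vmap(h))(jnp.arange(3.0))
jax.effects_barrier()  # wait for all callbacks before inspecting `calls`
print(calls)
```

With three batch elements you would usually see three calls with scalar shapes. This is the kind of duplication a debug callback is allowed to reveal.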
