jax.debug.callback


jax.debug.callback(callback, *args, ordered=False, partitioned=False, **kwargs)

Calls a stageable Python callback.

For more explanation, see External Callbacks.

jax.debug.callback lets you pass in a Python function that can be called inside a staged JAX program. A jax.debug.callback follows the existing pure operational semantics of JAX transformations, which are unaware of side effects. This means the effect could be dropped, duplicated, or potentially reordered in the presence of higher-order primitives and transformations.
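As a minimal sketch of the paragraph above, the following stages a callback inside a jax.jit-compiled function. The helper name record and the list seen are illustrative, not part of JAX. jax.effects_barrier() is called so that any pending callback has run before the host-side list is inspected.

```python
import jax
import jax.numpy as jnp

seen = []  # host-side storage, filled by the callback (illustrative name)

def record(x):
    # Runs on the host with the concrete value of x; must return None.
    seen.append(float(x))

@jax.jit
def f(x):
    y = jnp.sin(x)
    # Stage a call to record with the runtime value of y.
    jax.debug.callback(record, y)
    return y * 2

out = f(jnp.float32(0.0))
jax.effects_barrier()  # wait for outstanding callbacks to finish
```

The callback does not feed anything back into the computation; it only observes the value of y at run time.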

We want this behavior because we’d like jax.debug.callback to be “innocuous”, i.e. we want these primitives to change the JAX computation as little as possible while revealing as much about it as possible, such as which parts of the computation are duplicated or dropped.
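One way to see a callback being duplicated is to apply jax.vmap. The sketch below assumes the current behavior in which the callback is invoked once per element of the mapped axis; the names record and calls are illustrative. No order of invocation is assumed, since ordered is left at its default of False.

```python
import jax
import jax.numpy as jnp

calls = []  # host-side log of values seen by the callback (illustrative name)

def record(x):
    # Each invocation is expected to receive one scalar slice of the batch.
    calls.append(float(x))

def f(x):
    jax.debug.callback(record, x)
    return x + 1

out = jax.vmap(f)(jnp.arange(3.0))
jax.effects_barrier()  # wait for outstanding callbacks to finish

# Compare without relying on invocation order.
observed = sorted(calls)
```

Here the single call site in f shows up as several host-side calls, which is the kind of duplication the callback is meant to reveal.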

Parameters:
  • callback (Callable[...,None]) – A Python callable returning None.

  • *args (Any) – The positional arguments to the callback.

  • ordered (bool) – A keyword-only argument used to indicate whether or not the staged-out computation will enforce ordering of this callback w.r.t. other ordered callbacks.

  • partitioned (bool) – If True, then print local shards only; this option avoids an all-gather of the operands. If False, print with logical operands; this option requires an all-gather of operands first.

  • **kwargs (Any) – The keyword arguments to the callback.
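The parameters above can be combined as in the following sketch, which passes one positional and one keyword argument to the callback and sets ordered=True so that the two callbacks run in program order. The names record and log are illustrative. Only numeric values are passed through *args and **kwargs here, on the cautious assumption that non-array values such as strings may not be accepted by every JAX version.

```python
import jax
import jax.numpy as jnp

log = []  # host-side log of (step, value) pairs (illustrative name)

def record(step, *, value):
    # step arrives positionally, value by keyword; both are concrete at call time.
    log.append((int(step), float(value)))

@jax.jit
def g(x):
    # ordered=True asks the staged computation to keep these two
    # callbacks in order relative to each other.
    jax.debug.callback(record, 0, value=x, ordered=True)
    jax.debug.callback(record, 1, value=x * 2, ordered=True)
    return x

result = g(jnp.float32(1.0))
jax.effects_barrier()  # wait for outstanding callbacks to finish
```

If a callback needs a fixed non-numeric argument such as a label, binding it with functools.partial before passing the callable is one way to keep it out of the traced arguments.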

Returns:

None

Return type:

None
