jax.debug.print
jax.debug.print(fmt, *args, ordered=False, partitioned=False, skip_format_check=False, _use_logging=False, **kwargs)
Prints values and works in staged-out JAX functions.
This function does not work with f-strings because formatting is delayed. So instead of jax.debug.print(f"hello {bar}"), write jax.debug.print("hello {bar}", bar=bar).

This function is a thin convenience wrapper around jax.debug.callback(). The implementation is essentially:

    def debug_print(fmt: str, *args, **kwargs):
        jax.debug.callback(
            lambda *args, **kwargs: print(fmt.format(*args, **kwargs)),
            *args, **kwargs)

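As a minimal sketch of the delayed-formatting point above, the following jitted function passes its runtime value as a keyword argument. The function name `scaled_sum` and the input are illustrative assumptions, not part of the JAX documentation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x):
    # Hypothetical example function, used only for illustration.
    total = jnp.sum(x)
    # Inside jit, `total` is a tracer while the function is being staged out.
    # An f-string would format that tracer at trace time, so the value is
    # passed as a keyword argument and formatted later, at run time.
    jax.debug.print("total = {total}", total=total)
    return 2.0 * total


result = scaled_sum(jnp.arange(4.0))
# Wait for any pending debug callbacks before the program continues.
jax.effects_barrier()
```

The format string is a plain string, so the placeholder `{total}` is only filled in once the concrete value exists on the host.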
It may be useful to call jax.debug.callback() directly instead of this convenience wrapper. For example, to get debug printing in logs, you might use jax.debug.callback() together with logging.log.
- Parameters:
fmt (str) – A format string, e.g. "hello {x}", that will be used to format input arguments, like str.format. See the Python docs on string formatting and format string syntax.
*args – A list of positional arguments to be formatted, as if passed to fmt.format.
ordered (bool) – A keyword-only argument used to indicate whether or not the staged-out computation will enforce ordering of this jax.debug.print with respect to other ordered jax.debug.print calls.
partitioned (bool) – If True, then print local shards only; this option avoids an all-gather of the operands. If False, print with logical operands; this option requires an all-gather of operands first.
skip_format_check (bool) – If True, the format string is not checked. This is useful when using the function from inside a Pallas TPU kernel, where scalar args will be printed after the format string.
**kwargs – Additional keyword arguments to be formatted, as if passed to fmt.format.
_use_logging (bool)
- Return type:
None
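
The suggestion above to combine jax.debug.callback() with logging.log could be sketched roughly as follows. The helper `log_value`, the logger name, and the function `normalize` are hypothetical names introduced for illustration; only `jax.debug.callback`, `jax.effects_barrier`, and the standard `logging` module are existing APIs.

```python
import logging

import jax
import jax.numpy as jnp

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("jax_debug_example")  # hypothetical logger name
logger.setLevel(logging.INFO)


def log_value(fmt, **kwargs):
    # Hypothetical helper: mirrors the wrapper shown earlier, but sends the
    # formatted message to `logging` instead of `print`.
    def _host_log(**values):
        logger.log(logging.INFO, fmt.format(**values))

    jax.debug.callback(_host_log, **kwargs)


@jax.jit
def normalize(x):
    norm = jnp.linalg.norm(x)
    log_value("norm = {norm}", norm=norm)
    return x / norm


out = normalize(jnp.array([3.0, 4.0]))
# Wait for pending debug callbacks so the log record is emitted.
jax.effects_barrier()
```

Routing through a logger lets the usual logging configuration (levels, handlers, formatters) decide where the debug output goes, which plain printing cannot do.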
