
jax.debug.print

jax.debug.print(fmt, *args, ordered=False, partitioned=False, skip_format_check=False, _use_logging=False, **kwargs)

Prints values and works in staged-out JAX functions, such as those transformed by jax.jit.
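The following is a minimal sketch of the difference from Python's built-in print. It assumes a recent JAX release running on a single device, and `scale` is a hypothetical function name. Under jax.jit, print runs only while the function is being traced and sees an abstract tracer. In contrast, jax.debug.print is staged into the compiled computation and prints the concrete value on every call.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scale(x):  # hypothetical example function
    # Python's print runs once, at trace time, and shows an abstract tracer.
    print("python print:", x)
    # jax.debug.print is staged out and runs on every call with concrete values.
    jax.debug.print("jax.debug.print: {x}", x=x)
    return x * 2


scale(jnp.arange(3))  # first call: traces, so both lines print
scale(jnp.arange(3))  # cached trace: only the jax.debug.print line prints
jax.effects_barrier()  # wait for pending debug callbacks to finish
```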

This function does not work with f-strings because formatting is delayed. So instead of jax.debug.print(f"hello {bar}"), write jax.debug.print("hello {bar}", bar=bar).
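Here is a minimal sketch of the recommended pattern inside a jitted function, using the hypothetical name `greet`. The f-string form is left commented out. If used, it would be evaluated once during tracing and bake the tracer's text into the message. Passing `bar` as a keyword argument defers formatting until the concrete value reaches the host.

```python
import jax
import jax.numpy as jnp


@jax.jit
def greet(bar):  # hypothetical example function
    # Avoid: jax.debug.print(f"hello {bar}")
    # The f-string is evaluated at trace time, so it would embed the tracer's
    # repr instead of the runtime value.
    jax.debug.print("hello {bar}", bar=bar)  # formatted once the value is known
    return bar + 1


greet(jnp.float32(1.5))  # prints: hello 1.5
jax.effects_barrier()
```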

This function is a thin convenience wrapper around jax.debug.callback(). The implementation is essentially:

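    import jax

    def debug_print(fmt: str, *args, **kwargs):
        # The callback runs on the host with concrete values, so the
        # formatting happens at run time rather than at trace time.
        jax.debug.callback(
            lambda *args, **kwargs: print(fmt.format(*args, **kwargs)),
            *args, **kwargs)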

It may be useful to call jax.debug.callback() directly instead of this convenience wrapper. For example, to get debug printing in logs, you might use jax.debug.callback() together with logging.log.
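The following is a minimal sketch of that approach. The hypothetical helper `debug_log` follows the same shape as the wrapper above. Instead of calling print, it passes the formatted message to Logger.log, the Logger counterpart of logging.log. The logger name `debug_values`, the INFO default level, and the `double` function are illustrative assumptions.

```python
import logging

import jax
import jax.numpy as jnp

logging.basicConfig()  # attach a default stderr handler to the root logger
logger = logging.getLogger("debug_values")  # hypothetical logger name
logger.setLevel(logging.INFO)


def debug_log(fmt: str, *args, level: int = logging.INFO, **kwargs) -> None:
    """Hypothetical logging analogue of jax.debug.print."""
    jax.debug.callback(
        # Runs on the host with concrete values and logs instead of printing.
        lambda *a, **kw: logger.log(level, fmt.format(*a, **kw)),
        *args,
        **kwargs,
    )


@jax.jit
def double(x):  # hypothetical example function
    y = x * 2
    debug_log("y = {y}", y=y)
    return y


double(jnp.float32(1.5))  # logs "y = 3.0" at INFO level
jax.effects_barrier()
```

Routing through logging lets the usual handler, level, and formatter configuration decide where debug output from staged-out code ends up.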

Parameters:
  • fmt (str) – A format string, e.g. "hello {x}", that will be used to format input arguments, like str.format. See the Python docs on string formatting and format string syntax.

  • *args – A list of positional arguments to be formatted, as if passed to fmt.format.

  • ordered (bool) – A keyword-only argument indicating whether the staged-out computation will enforce ordering of this jax.debug.print with respect to other ordered jax.debug.print calls.

  • partitioned (bool) – If True, then print local shards only; this option avoids an all-gather of the operands. If False, print with logical operands; this option requires an all-gather of operands first.

  • skip_format_check (bool) – If True, the format string is not checked. This is useful when using the function from inside a Pallas TPU kernel, where scalar args will be printed after the format string.

  • **kwargs – Additional keyword arguments to be formatted, as if passed to fmt.format.

  • _use_logging (bool)
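The following minimal sketch exercises fmt, *args, **kwargs, and ordered together. It assumes a single-device setup, and `add_xy` is a hypothetical function name. As with str.format, positional arguments fill {} fields and keyword arguments fill named fields. Setting ordered=True on both calls asks JAX to keep the two prints in program order relative to each other.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_xy(x, y):  # hypothetical example function
    # x fills the positional "{}" field; y fills the named "{y}" field.
    jax.debug.print("first: x={}, y={y}", x, y=y, ordered=True)
    # Both calls are ordered, so this one is sequenced after the print above.
    jax.debug.print("second: sum={s}", s=x + y, ordered=True)
    return x + y


add_xy(jnp.float32(1.0), jnp.float32(2.0))
jax.effects_barrier()
```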

Return type:

None
