jax.debug module
Runtime value debugging utilities
“Compiled prints and breakpoints” describes how to use JAX’s runtime value debugging features.
| Calls a stageable Python callback. |
| Prints values and works in staged out JAX functions. |
| Enters a breakpoint at a point in a program. |
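Below is a minimal sketch of runtime value debugging inside a jitted function. It assumes a recent JAX release that provides `jax.debug.print`, `jax.debug.callback` and `jax.effects_barrier`. The `record` helper and the `seen` list are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical host-side log used only for this illustration.
seen = []

def record(y):
    # Called on the host with the concrete runtime value of y, even under jit.
    seen.append(float(y))

@jax.jit
def f(x):
    y = jnp.sin(x) * 2.0
    # A plain print() here would only show an abstract tracer at trace time;
    # jax.debug.print prints the runtime value each time the compiled code runs.
    jax.debug.print("y = {y}", y=y)
    # jax.debug.callback stages a call to an arbitrary Python function.
    jax.debug.callback(record, y)
    return y + 1.0

out = f(jnp.float32(0.0))
# Wait for outstanding debug effects (prints and callbacks) to complete.
jax.effects_barrier()
print(out)
```

You can place `jax.debug.breakpoint()` at the same point in `f` to pause in an interactive debugger. It is omitted here because it requires an interactive session.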
Sharding debugging utilities
Functions that enable inspecting and visualizing array shardings inside (and outside) staged functions.
| Enables inspecting array sharding inside jitted functions. |
| Visualizes an array's sharding. |
| Visualizes a |
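Below is a minimal sketch of inspecting a sharding from inside a staged function. It assumes a recent JAX release where `jax.debug.inspect_array_sharding(value, *, callback)` passes a `jax.sharding.Sharding` to `callback`. The `shardings` list is a hypothetical name used only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical list that collects whatever sharding the compiler reports.
shardings = []

@jax.jit
def g(x):
    # The callback runs while g is being compiled, not on every call,
    # and receives a jax.sharding.Sharding describing x's layout.
    jax.debug.inspect_array_sharding(x, callback=shardings.append)
    return x * 2

x = jnp.arange(8.0)
y = g(x)
print(shardings[0])
```

On a single device, the reported sharding is fully replicated. With inputs sharded across several devices, it reflects the layout chosen by the partitioner. For a concrete array outside `jit`, `jax.debug.visualize_array_sharding(y)` renders a diagram of its sharding. It is left out of this sketch because it draws with the third-party `rich` package.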
