jax.tree module
Utilities for working with tree-like container data structures.
The jax.tree namespace contains aliases of utilities from jax.tree_util.
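Because jax.tree re-exports jax.tree_util utilities under shorter names, the two spellings should produce identical results. The sketch below assumes a JAX release recent enough to provide the jax.tree namespace. The `params` dictionary is purely illustrative.

```python
import jax
import jax.numpy as jnp

# An illustrative pytree: a dict holding an array and a nested tuple of arrays.
params = {"w": jnp.ones((2, 3)), "b": (jnp.zeros(3), jnp.arange(3.0))}

# The short jax.tree spelling and the longer jax.tree_util spelling.
doubled_short = jax.tree.map(lambda x: 2 * x, params)
doubled_long = jax.tree_util.tree_map(lambda x: 2 * x, params)

# Both yield trees with the same structure and the same leaf values.
assert jax.tree.structure(doubled_short) == jax.tree_util.tree_structure(doubled_long)
for a, b in zip(jax.tree.leaves(doubled_short), jax.tree_util.tree_leaves(doubled_long)):
    assert jnp.array_equal(a, b)
```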
List of Functions
| Call all() over the leaves of a tree. |
| Broadcasts a tree prefix into the full structure of a given tree. |
| Flattens a pytree. |
| Flattens a pytree and also returns each leaf's key path. |
| Gets the leaves of a pytree. |
| Gets the leaves of a pytree and also returns each leaf's key path. |
| Maps a multi-input function over pytree args to produce a new pytree. |
| Maps a multi-input function over pytree key path and args to produce a new pytree. |
| Call reduce() over the leaves of a tree. |
| Perform a reduction over a pytree with an associative binary operation. |
| Convenience wrapper to declare a static pytree attribute. |
| Gets the treedef for a pytree. |
| Transform a tree having tree structure (outer, inner) into one having structure (inner, outer). |
| Reconstructs a pytree from the treedef and the leaves. |
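The flattening, structure, and reconstruction utilities listed above fit together as a round trip. A minimal sketch, assuming a JAX release that provides the jax.tree namespace; the nested `tree` is an illustrative example, and plain Python ints serve as leaves.

```python
import jax

# An illustrative nested container; the ints are the leaves.
tree = {"a": [1, 2], "b": (3, {"c": 4})}

leaves, treedef = jax.tree.flatten(tree)
# Dict keys are visited in sorted order, so the leaves are [1, 2, 3, 4].
print(leaves)

# structure() returns the same treedef that flatten() produced.
assert jax.tree.structure(tree) == treedef
# leaves() returns the same list that flatten() produced.
assert jax.tree.leaves(tree) == leaves

# unflatten() rebuilds a tree of the same shape from a new list of leaves.
rebuilt = jax.tree.unflatten(treedef, [x * 10 for x in leaves])
print(rebuilt)  # {'a': [10, 20], 'b': (30, {'c': 40})}
```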
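The multi-input mapping utilities apply a function leaf by leaf across several trees that share one structure, and the key-path variant also passes each leaf's location. The sketch below is a hypothetical SGD-style update; `params`, `grads`, `lr`, and `describe` are illustrative names. For the key-path variant it uses the jax.tree_util spelling (`tree_map_with_path`, `keystr`), on the assumption that those long-standing names are available in more JAX releases than their jax.tree aliases.

```python
import jax
import jax.numpy as jnp

# Two illustrative trees with identical structure.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([0.5])}
grads = {"w": jnp.array([0.1, 0.2]), "b": jnp.array([1.0])}

# map() calls f(p, g) on matching leaves of params and grads.
lr = 0.1
updated = jax.tree.map(lambda p, g: p - lr * g, params, grads)

# The key-path variant passes each leaf's path as the first argument.
def describe(path, p):
    # keystr renders a key path as a readable string, for example "['w']".
    return f"{jax.tree_util.keystr(path)}: shape {p.shape}"

labels = jax.tree_util.tree_map_with_path(describe, params)
print(labels)
```

All trees passed after the first must match its structure; a mismatch raises an error rather than being silently broadcast.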
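The reduction helpers listed above fold over a tree's leaves in the same order that flattening produces. A minimal sketch, assuming a JAX release that provides the jax.tree namespace; the `tree` below is illustrative.

```python
import operator
import jax

# An illustrative tree with five integer leaves.
tree = {"x": [1, 2, 3], "y": (4, 5)}

# reduce() folds a binary function over the leaves (here, a sum).
total = jax.tree.reduce(operator.add, tree)
print(total)  # 15

# An optional initializer is used as the starting value of the fold.
product = jax.tree.reduce(operator.mul, tree, 1)
print(product)  # 120

# all() is True only if every leaf is truthy.
print(jax.tree.all(jax.tree.map(lambda v: v > 0, tree)))  # True
```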
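Transposition takes the outer and inner tree structures explicitly and swaps their nesting. The sketch below turns a list of dicts into a dict of lists. It assumes a JAX release that provides the jax.tree namespace, and the data is illustrative. The treedefs are built by calling `jax.tree.structure` on placeholder trees of the desired shape.

```python
import jax

# Outer structure: a 2-element list. Inner structure: a dict with keys "a" and "b".
pytree = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]

outer = jax.tree.structure([0, 0])
inner = jax.tree.structure({"a": 0, "b": 0})

# transpose() regroups leaves so the dict is outside and the list is inside.
swapped = jax.tree.transpose(outer, inner, pytree)
print(swapped)  # {'a': [1, 3], 'b': [2, 4]}
```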
