jax.tree_util module
Utilities for working with tree-like container data structures.
This module provides a small set of utility functions for working with tree-like data structures, such as nested tuples, lists, and dicts. We call these structures pytrees. They are trees in that they are defined recursively (any non-pytree is a pytree, i.e. a leaf, and any pytree of pytrees is a pytree) and can be operated on recursively (object identity equivalence is not preserved by mapping operations, and the structures cannot contain reference cycles).
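As a minimal sketch of the recursive definition above, the following flattens a nested container into its leaves, maps over it, and rebuilds it. The example tree and variable names are illustrative assumptions; only the `jax.tree_util` calls come from the library.

```python
import jax

# A nested container of dicts, tuples, and lists; the integers are the leaves.
tree = {"a": (1, 2), "b": [3, {"c": 4}]}

# Flatten into a flat list of leaves plus a treedef describing the structure.
leaves, treedef = jax.tree_util.tree_flatten(tree)

# Map a function over every leaf; the result has the same structure.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, tree)

# Rebuild a tree from the treedef and a list of leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# The rebuilt tree is equal to the original but is a new object:
# object identity is not preserved by these operations.
same_value = rebuilt == tree
same_object = rebuilt is tree
```

Here `same_value` is true while `same_object` is false, which illustrates the note about object identity.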
The set of Python types that are considered pytree nodes (e.g. that can be mapped over, rather than treated as leaves) is extensible. There is a single module-level registry of types, and class hierarchy is ignored. By registering a new pytree node type, that type in effect becomes transparent to the utility functions in this file.
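The effect of registration can be sketched as follows. The `Point` class and its helper functions are hypothetical and exist only for illustration; `register_pytree_node` expects a flatten function returning `(children, aux_data)` and an unflatten function taking `(aux_data, children)`.

```python
import jax

class Point:
    """Hypothetical user-defined container, used only for illustration."""
    def __init__(self, x, y, name):
        self.x = x
        self.y = y
        self.name = name

def point_flatten(p):
    # Children are the values to map over; aux data is static metadata.
    return (p.x, p.y), p.name

def point_unflatten(name, children):
    x, y = children
    return Point(x, y, name)

p = Point(1.0, 2.0, "origin")

# Before registration, a Point is an opaque leaf.
leaves_before = jax.tree_util.tree_leaves(p)

jax.tree_util.register_pytree_node(Point, point_flatten, point_unflatten)

# After registration, the utilities look inside the Point.
leaves_after = jax.tree_util.tree_leaves(p)
shifted = jax.tree_util.tree_map(lambda v: v + 1.0, p)
```

Because class hierarchy is ignored, a subclass of `Point` would presumably need its own registration to be treated as a node rather than as a leaf.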
The primary purpose of this module is to enable the interoperability between user defined data structures and JAX transformations (e.g. jit). This is not meant to be a general purpose tree-like data structure handling library.
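A minimal sketch of that interoperability, assuming a hypothetical `Params` container: once the class is registered as a pytree node, instances can be passed into and returned from a `jit`-compiled function.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Params:
    """Hypothetical parameter container, used only for illustration."""
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def tree_flatten(self):
        # Both fields are children; there is no static aux data.
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

@jax.jit
def predict(params, x):
    # jit flattens `params` into its array leaves and traces over them.
    return params.w * x + params.b

@jax.jit
def scale(params, factor):
    # A registered container can also be returned from a jitted function.
    return Params(params.w * factor, params.b * factor)

params = Params(jnp.array(2.0), jnp.array(1.0))
y = predict(params, jnp.array(3.0))
scaled = scale(params, 10.0)
```

The class decorator is used here for brevity; registering explicit flatten and unflatten functions should work equally well.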
See the JAX pytrees note for examples.
List of Functions
| A version of functools.partial that works in pytrees. |
| Tests whether all elements in the given iterable are all leaves. |
| Extends the set of types that are considered internal nodes in pytrees. |
| Extends the set of types that are considered internal nodes in pytrees. |
| Extends the set of types that are considered internal nodes in pytrees. |
| Extends the set of types that are considered internal nodes in pytrees. |
| Extends the set of types that are considered internal nodes in pytrees. |
| Registers cls as a pytree with no leaves. |
| Alias of |
| Alias of |
| Alias of |
| Return a list of treedefs for immediate children |
| Return True if the treedef represents a leaf. |
| Makes a tuple treedef from an iterable of child treedefs. |
| Type variable. |
| Built-in immutable sequence. |
| Helper to pretty-print a tuple of keys. |
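The table above does not show the function names, so the following sketch rests on an assumption about which `jax.tree_util` functions the descriptions refer to: `Partial`, `all_leaves`, `treedef_children`, `treedef_is_leaf`, `treedef_tuple`, and `keystr`. The example tree is illustrative.

```python
import jax

tree = {"a": (1, 2), "b": [3]}
leaves, treedef = jax.tree_util.tree_flatten(tree)

# all_leaves: True only if no element of the iterable is a container node.
flat_ok = jax.tree_util.all_leaves(leaves)
nested_ok = jax.tree_util.all_leaves([tree])

# treedef_children / treedef_is_leaf: inspect a treedef one level at a time.
children = jax.tree_util.treedef_children(treedef)
leaf_def = jax.tree_util.tree_structure(1)
root_is_leaf = jax.tree_util.treedef_is_leaf(treedef)
leaf_is_leaf = jax.tree_util.treedef_is_leaf(leaf_def)

# treedef_tuple: build the treedef of a tuple from its children's treedefs.
pair_def = jax.tree_util.treedef_tuple([leaf_def, leaf_def])

# keystr: pretty-print the key path of each leaf.
path_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
paths = [jax.tree_util.keystr(path) for path, _ in path_leaves]

# Partial: like functools.partial, but the bound arguments are pytree leaves,
# so the partial object itself can be passed to a jitted function.
add_one = jax.tree_util.Partial(lambda x, y: x + y, 1.0)
bound = jax.tree_util.tree_leaves(add_one)
result = jax.jit(lambda f, y: f(y))(add_one, 2.0)
```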
Legacy APIs
These APIs are now accessed via jax.tree.
| Alias of |
| Alias of |
| Alias of |
| Alias of |
| Alias of |
| Alias of |
| Alias of |
| Alias of |
| Alias of |
| Alias of |
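As a rough illustration of the relationship between the two namespaces, and assuming a JAX version recent enough to provide `jax.tree`, the legacy `jax.tree_util.tree_*` functions and their `jax.tree` counterparts can be used interchangeably. The alias targets are not listed in the table above, so the pairings below are an assumption based on the function names.

```python
import jax

tree = {"a": (1, 2), "b": [3]}

# Legacy spelling in jax.tree_util.
old_mapped = jax.tree_util.tree_map(lambda x: x * 10, tree)
old_leaves = jax.tree_util.tree_leaves(tree)

# Equivalent spelling in the jax.tree namespace.
new_mapped = jax.tree.map(lambda x: x * 10, tree)
new_leaves = jax.tree.leaves(tree)

# Flatten and unflatten round-trip through the jax.tree namespace.
leaves, treedef = jax.tree.flatten(tree)
rebuilt = jax.tree.unflatten(treedef, leaves)
same_structure = jax.tree.structure(tree) == treedef
```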
