jax.tree_util module

Utilities for working with tree-like container data structures.

This module provides a small set of utility functions for working with tree-like data structures, such as nested tuples, lists, and dicts. We call these structures pytrees. They are trees in that they are defined recursively (any non-pytree is a pytree, i.e. a leaf, and any pytree of pytrees is a pytree) and can be operated on recursively (object identity equivalence is not preserved by mapping operations, and the structures cannot contain reference cycles).
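A minimal sketch of the recursive behavior described above, using only built-in containers: tree_flatten splits a pytree into its leaves and a treedef, tree_unflatten rebuilds it, and tree_map returns new containers rather than the original objects. The variable names are illustrative only.

```python
import jax

# Nested built-in containers are pytrees; the ints are leaves.
tree = {"b": (3, 4), "a": [1, 2]}

leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)   # [1, 2, 3, 4]  (dict keys are visited in sorted order)
print(treedef)  # a PyTreeDef describing the container structure

# Rebuilding from the treedef recovers an equal (but new) structure.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(rebuilt == tree)  # True

# Mapping builds new containers, so object identity is not preserved.
mapped = jax.tree_util.tree_map(lambda x: x, tree)
print(mapped == tree, mapped is tree)  # True False
```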

The set of Python types that are considered pytree nodes (i.e. that can be mapped over, rather than treated as leaves) is extensible. There is a single module-level registry of types, and class hierarchy is ignored. Registering a new pytree node type in effect makes that type transparent to the utility functions in this module.
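As a sketch of how registration works, the snippet below registers a hypothetical Box class with register_pytree_node, so the tree utilities traverse its contents while keeping label as auxiliary data in the treedef. It also illustrates that class hierarchy is ignored: MyList, another hypothetical name, subclasses list but is never registered, so the whole object is treated as a single leaf.

```python
import jax


class Box:
    """A hypothetical container used only for illustration."""

    def __init__(self, contents, label):
        self.contents = contents  # child pytree: traversed by tree utilities
        self.label = label        # auxiliary data: stored in the treedef


def flatten_box(box):
    # Return (children, aux_data).
    return (box.contents,), box.label


def unflatten_box(label, children):
    # Rebuild a Box from the aux_data and the (possibly transformed) children.
    (contents,) = children
    return Box(contents, label)


jax.tree_util.register_pytree_node(Box, flatten_box, unflatten_box)

box = Box([1, 2, 3], "ints")
print(jax.tree_util.tree_leaves(box))  # [1, 2, 3]

doubled = jax.tree_util.tree_map(lambda x: 2 * x, box)
print(doubled.contents, doubled.label)  # [2, 4, 6] ints


# The registry is keyed on the exact type: a subclass of list is not
# automatically a pytree node, so the whole object is a single leaf.
class MyList(list):
    pass


print(len(jax.tree_util.tree_leaves(MyList([1, 2, 3]))))  # 1
```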

The primary purpose of this module is to enable interoperability between user-defined data structures and JAX transformations (e.g. jit). It is not meant to be a general-purpose library for handling tree-like data structures.
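One way to see this interoperability, assuming a JAX version that provides jax.tree_util.register_dataclass: the hypothetical LinearParams dataclass is registered as a pytree node and then passed directly into a jit-compiled function, which traces its array fields and treats name as static metadata.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class LinearParams:
    """Hypothetical parameter container for a linear model."""
    weight: jax.Array
    bias: jax.Array
    name: str  # metadata: part of the treedef, so it must be hashable


jax.tree_util.register_dataclass(
    LinearParams, data_fields=["weight", "bias"], meta_fields=["name"]
)


@jax.jit
def predict(params, x):
    # params arrives as a LinearParams whose array fields are traced.
    return x @ params.weight + params.bias


params = LinearParams(weight=jnp.ones((3, 2)), bias=jnp.zeros(2), name="linear")
print(predict(params, jnp.ones(3)))  # [3. 3.]
```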

See the JAX pytrees note for examples.

List of Functions

Partial(func, *args, **kw)

A version of functools.partial that works in pytrees.

all_leaves(iterable[, is_leaf])

Tests whether all elements in the given iterable are leaves.

register_dataclass(nodetype[, data_fields, ...])

Extends the set of types that are considered internal nodes in pytrees.

register_pytree_node(nodetype, flatten_func, ...)

Extends the set of types that are considered internal nodes in pytrees.

register_pytree_node_class(cls)

Extends the set of types that are considered internal nodes in pytrees.

register_pytree_with_keys(nodetype, ...[, ...])

Extends the set of types that are considered internal nodes in pytrees.

register_pytree_with_keys_class(cls)

Extends the set of types that are considered internal nodes in pytrees.

register_static(cls)

Registers cls as a pytree with no leaves.

tree_flatten_with_path(tree[, is_leaf, ...])

Alias of jax.tree.flatten_with_path().

tree_leaves_with_path(tree[, is_leaf, ...])

Alias of jax.tree.leaves_with_path().

tree_map_with_path(f, tree, *rest[, ...])

Alias of jax.tree.map_with_path().

treedef_children(treedef)

Return a list of treedefs for the immediate children.

treedef_is_leaf(treedef)

Return True if the treedef represents a leaf.

treedef_tuple(treedefs)

Makes a tuple treedef from an iterable of child treedefs.

KeyEntry

Type variable for a single entry in a key path.

KeyPath

A tuple of key entries giving the path to a leaf within a pytree.

keystr(keys, *[, simple, separator])

Helper to pretty-print a tuple of keys.
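A short sketch of the key-path helpers and treedef utilities listed above, applied to a small illustrative dict. keystr renders each key path in the bracketed form shown in the comments; the tree itself is hypothetical.

```python
from jax import tree_util

tree = {"params": [1.0, 2.0], "step": 0}

# Each leaf is paired with its key path (a tuple of key entries).
for path, leaf in tree_util.tree_leaves_with_path(tree):
    print(tree_util.keystr(path), leaf)
# ['params'][0] 1.0
# ['params'][1] 2.0
# ['step'] 0

# tree_map_with_path passes the key path as the first argument to f.
tagged = tree_util.tree_map_with_path(lambda p, x: tree_util.keystr(p), tree)
print(tagged["step"])  # ['step']

# Inspecting a treedef: the top-level dict has two children, and the
# treedef of a bare leaf reports itself as a leaf.
treedef = tree_util.tree_structure(tree)
children = tree_util.treedef_children(treedef)
print(len(children))  # 2
print(tree_util.treedef_is_leaf(tree_util.tree_structure(1.0)))  # True
```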

Legacy APIs

These APIs are now accessed via jax.tree.

tree_all(tree, *[, is_leaf])

Alias of jax.tree.all().

tree_broadcast(prefix_tree, full_tree[, is_leaf])

Alias of jax.tree.broadcast().

tree_flatten(tree[, is_leaf])

Alias of jax.tree.flatten().

tree_leaves(tree[, is_leaf])

Alias of jax.tree.leaves().

tree_map(f, tree, *rest[, is_leaf])

Alias of jax.tree.map().

tree_reduce(function, tree[, initializer, ...])

Alias of jax.tree.reduce().

tree_reduce_associative(operation, tree, *)

Alias of jax.tree.reduce_associative().

tree_structure(tree[, is_leaf])

Alias of jax.tree.structure().

tree_transpose(outer_treedef, inner_treedef, ...)

Alias of jax.tree.transpose().

tree_unflatten(treedef, leaves)

Alias of jax.tree.unflatten().
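Because the legacy names are aliases, they should compute the same results as their jax.tree counterparts. The sketch below, with hypothetical trees a and b and assuming a JAX version that ships the jax.tree module, compares jax.tree_util.tree_map with jax.tree.map over two input trees and uses jax.tree.reduce to count array elements across the leaves.

```python
import jax
import jax.numpy as jnp

a = {"x": jnp.array([1.0, 2.0]), "y": jnp.array(3.0)}
b = {"x": jnp.array([10.0, 20.0]), "y": jnp.array(30.0)}

# Legacy spelling and the jax.tree spelling; both map over leaves of a and b
# pairwise, which requires the two trees to share the same structure.
legacy = jax.tree_util.tree_map(lambda u, v: u + v, a, b)
current = jax.tree.map(lambda u, v: u + v, a, b)
print(current["x"])  # [11. 22.]

# Reduce over leaves with an initializer: total number of array elements.
n_elements = jax.tree.reduce(lambda acc, leaf: acc + leaf.size, a, 0)
print(n_elements)  # 3
```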

