Pytrees#
JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures. In JAX these are called pytrees. This section explains how to use them, provides useful code examples, and points out common “gotchas” and patterns.
For an explanation of how to create custom pytrees, see Custom pytree nodes.
What is a pytree?#
A pytree is a container-like structure built out of container-like Python objects: “leaf” pytrees and/or more pytrees. A pytree can include lists, tuples, and dicts. A leaf is anything that is not a registered container, such as an array; a single leaf on its own also counts as a pytree.
In the context of machine learning (ML), a pytree can contain:
Model parameters
Dataset entries
Reinforcement learning agent observations
When working with datasets, you can often come across pytrees (such as lists of lists of dicts).
Below are some examples of simple pytrees. In JAX, you can use jax.tree.leaves() to extract the flattened leaves from a tree, as demonstrated here:
import jax
import jax.numpy as jnp

example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Print how many leaves the pytrees have.
for pytree in example_trees:
  # This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees.
  leaves = jax.tree.leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x7227811b8490>] has 3 leaves: [1, 'a', <object object at 0x7227811b8490>]
(1, (2, 3), ()) has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5] has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)} has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32) has 1 leaves: [Array([1, 2, 3], dtype=int32)]
Any tree-like structure built out of container-like Python objects can be treated as a pytree in JAX. Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts. Any object whose type is not in the pytree container registry will be treated as a leaf node in the tree.
The pytree registry can be extended to include user-defined container classes by registering the class with functions that specify how to flatten the tree; see Custom pytree nodes.
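As a minimal sketch of the leaf rule above, the hypothetical class `Point` below is not registered as a pytree node, so JAX treats each instance as one opaque leaf even though it holds two values. The class name and its fields are assumptions made purely for illustration.

```python
import jax

class Point:
  """Hypothetical unregistered container class (illustration only)."""

  def __init__(self, x, y):
    self.x = x
    self.y = y

tree = [Point(1, 2), (3, 4)]
leaves = jax.tree.leaves(tree)

# The Point instance is a single leaf; the tuple contributes two leaves.
print(len(leaves))
```

To make JAX look inside `Point`, the class would have to be registered as described in Custom pytree nodes.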
Common pytree functions#
JAX provides a number of utilities to operate over pytrees. These can be found in the jax.tree_util subpackage; for convenience many of these have aliases in the jax.tree module.
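Two of these utilities, jax.tree.flatten() and jax.tree.unflatten(), can be sketched as a round trip: split a tree into its leaves plus a structure description, transform the flat list, then rebuild. The example tree is an arbitrary choice for illustration.

```python
import jax

tree = [1, {'a': 2, 'b': (3, 4)}]

# Split the tree into a flat list of leaves and a structure description.
leaves, treedef = jax.tree.flatten(tree)

# Transform the leaves, then rebuild a tree with the original structure.
doubled = jax.tree.unflatten(treedef, [2 * leaf for leaf in leaves])

print(leaves)
print(doubled)
```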
Common function: jax.tree.map#
The most commonly used pytree function is jax.tree.map(). It works analogously to Python’s native map, but transparently operates over entire pytrees.
Here’s an example:
list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

jax.tree.map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
jax.tree.map() also allows mapping an N-ary function over multiple arguments. For example:
another_list_of_lists = list_of_lists
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
When using multiple arguments with jax.tree.map(), the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc.
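The matching requirement can be illustrated with a small sketch in which the second inner list has one element too many. In the JAX versions this sketch assumes, the mismatch surfaces as a ValueError; treat the exact exception type and message as an assumption rather than a documented guarantee.

```python
import jax

left = [[1, 2, 3], [1, 2]]
right = [[1, 2, 3], [1, 2, 3]]  # second inner list has one extra element

try:
  jax.tree.map(lambda x, y: x + y, left, right)
  mismatch_detected = False
except ValueError as err:
  # Assumed behavior: JAX reports the structural mismatch as a ValueError.
  mismatch_detected = True
  print(f"Structure mismatch: {err}")
```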
Example of jax.tree.map with ML model parameters#
This example demonstrates how pytree operations can be useful when training a simple multi-layer perceptron (MLP).
Begin with defining the initial model parameters:
import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )
  return params

params = init_mlp_params([1, 128, 128, 1])
Use jax.tree.map() to check the shapes of the initial parameters:
jax.tree.map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)}, {'biases': (128,), 'weights': (128, 128)}, {'biases': (1,), 'weights': (128, 1)}]
Next, define the functions for training the MLP model:
# Define the forward pass.
def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

# Define the loss function.
def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

# Set the learning rate.
LEARNING_RATE = 0.0001

# Using stochastic gradient descent, define the parameter update function.
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
  # Calculate the gradients with `jax.grad`.
  grads = jax.grad(loss_fn)(params, x, y)
  # Note that `grads` is a pytree with the same structure as `params`.
  # `jax.grad` is one of many JAX functions that has
  # built-in support for pytrees.
  # This is useful - you can apply the SGD update using JAX pytree utilities.
  return jax.tree.map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )
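To show how the pieces fit together, here is a self-contained sketch that repeats the definitions above and runs a few update steps. The toy dataset (fitting y = x**2 on random inputs), the seed, and the number of steps are assumptions chosen for illustration, not part of the original example.

```python
import jax
import jax.numpy as jnp
import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(dict(
        weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2 / n_in),
        biases=np.ones(shape=(n_out,))))
  return params

def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

LEARNING_RATE = 0.0001

@jax.jit
def update(params, x, y):
  grads = jax.grad(loss_fn)(params, x, y)
  return jax.tree.map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Assumed toy data: fit y = x**2 on random inputs.
np.random.seed(0)
xs = np.random.normal(size=(128, 1))
ys = xs ** 2

params = init_mlp_params([1, 128, 128, 1])
initial_structure = jax.tree.structure(params)

for _ in range(100):
  params = update(params, xs, ys)

# The update preserves the pytree structure and the leaf shapes.
print(jax.tree.structure(params) == initial_structure)
print(jax.tree.map(lambda p: p.shape, params))
```

Because `update` returns a pytree with the same structure as its input, its output can be fed straight back in on the next iteration.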
Viewing the pytree definition of an object#
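To view the pytree definition of an arbitrary object for debugging purposes, you can use:
from jax.tree_util import tree_structure
# Replace `object` with the value to inspect. The built-in `object` type
# used here is not a registered container, so it is reported as a single leaf.
print(tree_structure(object))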
PyTreeDef(*)
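For a nested container the printed definition is more informative. The following sketch inspects an arbitrary example tree; the exact text printed for the definition may vary between JAX versions, so only the leaf count is relied on here.

```python
import jax.numpy as jnp
from jax.tree_util import tree_structure

tree = [1, {'k1': jnp.zeros(2), 'k2': (3, 4)}, 5]
treedef = tree_structure(tree)

print(treedef)             # textual description of the container nesting
print(treedef.num_leaves)  # number of leaves in the tree
```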
Pytrees and JAX transformations#
Many JAX functions, like jax.lax.scan(), operate over pytrees of arrays. In addition, all JAX function transformations can be applied to functions that accept as input and produce as output pytrees of arrays.
Some JAX function transformations take optional parameters that specify how certain input or output values should be treated (such as the in_axes and out_axes arguments to jax.vmap()). These parameters can also be pytrees, and their structure must correspond to the pytree structure of the corresponding arguments. In particular, to be able to “match up” leaves in these parameter pytrees with values in the argument pytrees, the parameter pytrees are often constrained to be tree prefixes of the argument pytrees.
For example, if you pass the following input to jax.vmap() (note that the input arguments to a function are considered a tuple):
vmap(f, in_axes=(a1, {"k1": a2, "k2": a3}))
then you can use the following in_axes pytree to specify that only the k2 argument is mapped (axis=0), and the rest aren’t mapped over (axis=None):
vmap(f, in_axes=(None, {"k1": None, "k2": 0}))
The optional parameter pytree structure must match that of the main input pytree. However, the optional parameters can also be specified as a “prefix” pytree, meaning that a single leaf value can be applied to an entire sub-pytree.
For example, if you have the same jax.vmap() input as above, but wish to only map over the dictionary argument, you can use:
vmap(f, in_axes=(None, 0))  # equivalent to (None, {"k1": 0, "k2": 0})
Alternatively, if you want every argument to be mapped, you can write a single leaf value that is applied over the entire argument tuple pytree:
vmap(f, in_axes=0)  # equivalent to (0, {"k1": 0, "k2": 0})
This happens to be the default in_axes value for jax.vmap().
The same logic applies to other optional parameters that refer to specific input or output values of a transformed function, such as out_axes in jax.vmap().
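The in_axes forms discussed above can be tried on a concrete function. In this sketch, the function `f` and the input values are assumptions for illustration; `f` simply adds its first argument to both dictionary entries.

```python
import jax
import jax.numpy as jnp

def f(a, d):
  return a + d["k1"] + d["k2"]

a = jnp.float32(1.0)

# Full in_axes pytree: map only over d["k2"]; `a` and d["k1"] are not mapped.
d = {"k1": jnp.float32(10.0), "k2": jnp.arange(3.0)}
out_k2 = jax.vmap(f, in_axes=(None, {"k1": None, "k2": 0}))(a, d)

# Prefix form: the single 0 applies to every leaf of the dict argument.
d_batched = {"k1": jnp.ones(3), "k2": jnp.arange(3.0)}
out_dict = jax.vmap(f, in_axes=(None, 0))(a, d_batched)

print(out_k2)
print(out_dict)
```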
Explicit key paths#
In a pytree each leaf has a key path. A key path for a leaf is a sequence of keys, where the length of the sequence is equal to the depth of the leaf in the pytree. Each key is a hashable object that represents an index into the corresponding pytree node type. The type of the key depends on the pytree node type; for example, the type of keys for dicts is different from the type of keys for tuples.
For built-in pytree node types, the set of keys for any pytree node instance is unique. For a pytree comprising nodes with this property, the key path for each leaf is unique.
JAX has the following jax.tree_util.* functions for working with key paths:
jax.tree_util.tree_flatten_with_path(): Works similarly to jax.tree.flatten(), but returns key paths.
jax.tree_util.tree_map_with_path(): Works similarly to jax.tree.map(), but the function also takes key paths as arguments.
jax.tree_util.keystr(): Given a general key path, returns a reader-friendly string expression.
For example, one use case is to print debugging information related to a certain leaf value:
import collections

ATuple = collections.namedtuple("ATuple", ('name'))

tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)

for key_path, value in flattened:
  print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo
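jax.tree_util.tree_map_with_path() can be sketched in the same spirit. The mapped function receives the key path as its first argument and the leaf as its second; the small example tree is an arbitrary choice.

```python
import jax

tree = {'a': 1, 'b': [2, 3]}

# Label every leaf with its own key path while keeping the tree structure.
labels = jax.tree_util.tree_map_with_path(
    lambda path, leaf: f"{jax.tree_util.keystr(path)}={leaf}", tree)

print(labels)
```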
To express key paths, JAX provides a few default key types for the built-in pytree node types, namely:
SequenceKey(idx: int): For lists and tuples.
DictKey(key: Hashable): For dictionaries.
GetAttrKey(name: str): For namedtuples and preferably custom pytree nodes (more in the next section).
You are free to define your own key types for your custom nodes. They will work with jax.tree_util.keystr() as long as their __str__() method is also overridden with a reader-friendly expression.
for key_path, _ in flattened:
  print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
Key path of tree[0]: (SequenceKey(idx=0),)
Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))
Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))
Common pytree gotchas#
This section covers some of the most common problems (“gotchas”) encountered when using JAX pytrees.
Mistaking pytree nodes for leaves#
A common gotcha to look out for is accidentally introducing tree nodes instead of leaves:
a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Try to make another pytree with ones instead of zeros.
shapes = jax.tree.map(lambda x: x.shape, a_tree)
jax.tree.map(jnp.ones, shapes)
[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)), (Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]
What happened here is that the shape of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling jnp.ones on e.g. (2, 3), it’s called on 2 and 3.
The solution will depend on the specifics, but there are two broadly applicable options:
Rewrite the code to avoid the intermediate jax.tree.map().
Convert the tuple into a NumPy array (np.array) or a JAX NumPy array (jnp.array), which makes the entire sequence a leaf.
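Both options can be sketched on the tree from the failing example. This is one possible way to apply them, not the only one: the first part builds the arrays of ones in a single map, and the second part only demonstrates that a shape stored as a NumPy array counts as one leaf.

```python
import jax
import jax.numpy as jnp
import numpy as np

a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Option 1: build the new arrays in one map, with no shape tuples in between.
ones_tree = jax.tree.map(lambda x: jnp.ones(x.shape), a_tree)
print([x.shape for x in ones_tree])

# Option 2: a shape stored as a NumPy array is a single leaf, not a node.
shapes = jax.tree.map(lambda x: np.array(x.shape), a_tree)
print(len(jax.tree.leaves(shapes)))
```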
Handling of None by jax.tree_util#
jax.tree_util functions treat None as the absence of a pytree node, not as a leaf:
jax.tree.leaves([None, None, None])
[]
To treat None as a leaf, you can use the is_leaf argument:
jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)
[None, None, None]
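The same rule affects jax.tree.map(). In the sketch below, the mapped function never sees None by default, while passing is_leaf lets it handle None explicitly. Replacing None with 0 is an arbitrary choice made for illustration.

```python
import jax

tree = [1, None, 2]

# By default None is an empty subtree: the function is never called on it.
incremented = jax.tree.map(lambda x: x + 1, tree)
print(incremented)

# With is_leaf, None reaches the function and can be replaced.
filled = jax.tree.map(
    lambda x: 0 if x is None else x, tree, is_leaf=lambda x: x is None)
print(filled)
```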
Common pytree patterns#
This section covers some of the most common patterns with JAX pytrees.
Transposing pytrees with jax.tree.map and jax.tree.transpose#
To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: jax.tree.map() (more basic) and jax.tree.transpose() (more flexible, complex and verbose).
Option 1: Use jax.tree.map(). Here’s an example:
def tree_transpose(list_of_trees):
  """
  Converts a list of trees of identical structure into a single tree of lists.
  """
  return jax.tree.map(lambda *xs: list(xs), *list_of_trees)

# Convert a dataset from row-major to column-major.
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)
{'obs': [3, 4], 't': [1, 2]}
Option 2: For more complex transposes, use jax.tree.transpose(), which is more verbose, but allows you to specify the structure of the inner and outer pytree for more flexibility. For example:
jax.tree.transpose(
  outer_treedef = jax.tree.structure([0 for e in episode_steps]),
  inner_treedef = jax.tree.structure(episode_steps[0]),
  pytree_to_transpose = episode_steps
)
{'obs': [3, 4], 't': [1, 2]}
Extending pytrees#
Material on extending pytrees has been moved to Custom pytree nodes.
