Custom pytree nodes

This section explains how to extend the set of Python types that JAX treats as internal nodes in pytrees (pytree nodes) by using jax.tree_util.register_pytree_node() together with jax.tree.map().

Why would you need this? In the previous examples, pytrees were shown as lists, tuples, and dicts, with everything else treated as a pytree leaf. If you define your own container class, JAX considers it a pytree leaf unless you register it, even if the class has trees inside it. For example:

import jax

class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

jax.tree.leaves([
    Special(0, 1),
    Special(2, 4),
])
[<__main__.Special at 0x71ad404f0b00>, <__main__.Special at 0x71ad40ce0f20>]

Accordingly, if you call jax.tree.map() expecting the leaves to be the elements inside the container, you will get an error:

jax.tree.map(lambda x: x + 1,
  [
    Special(0, 1),
    Special(2, 4)
  ])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 1
----> 1 jax.tree.map(lambda x: x + 1,
      2   [
      3     Special(0, 1),
      4     Special(2, 4)
      5   ])

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/tree.py:155, in map(f, tree, is_leaf, *rest)
    115 def map(f: Callable[..., Any],
    116         tree: Any,
    117         *rest: Any,
    118         is_leaf: Callable[[Any], bool] | None = None) -> Any:
    119   """Maps a multi-input function over pytree args to produce a new pytree.
    120
    121   Args:
   (...)
    153     - :func:`jax.tree.reduce`
    154   """
--> 155   return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/tree_util.py:361, in tree_map(f, tree, is_leaf, *rest)
    359 leaves, treedef = tree_flatten(tree, is_leaf)
    360 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 361 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/tree_util.py:361, in <genexpr>(.0)
    359 leaves, treedef = tree_flatten(tree, is_leaf)
    360 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 361 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

Cell In[2], line 1, in <lambda>(x)
----> 1 jax.tree.map(lambda x: x + 1,
      2   [
      3     Special(0, 1),
      4     Special(2, 4)
      5   ])

TypeError: unsupported operand type(s) for +: 'Special' and 'int'

As a solution, JAX lets you extend the set of types that are considered internal pytree nodes through a global registry of types. The values of registered types are then traversed recursively.

First, register a new type using jax.tree_util.register_pytree_node():

from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

  Params:
    v: The value of the registered type to flatten.
  Returns:
    A pair of an iterable with the children to be flattened recursively,
    and some opaque auxiliary data to pass back to the unflattening recipe.
    The auxiliary data is stored in the treedef for use during unflattening.
    The auxiliary data could be used, for example, for dictionary keys.
  """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

  Params:
    aux_data: The opaque data that was specified during flattening of the
      current tree definition.
    children: The unflattened children
  Returns:
    A reconstructed object of the registered type, using the specified
    children and auxiliary data.
  """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # Instruct JAX what are the children nodes.
    special_unflatten   # Instruct JAX how to pack back into a `RegisteredSpecial`.
)

Now you can traverse the special container structure:

jax.tree.map(lambda x: x + 1,
  [
    RegisteredSpecial(0, 1),
    RegisteredSpecial(2, 4),
  ])
[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]

Alternatively, you can define appropriate tree_flatten and tree_unflatten methods on your class and decorate it with register_pytree_node_class():

from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class RegisteredSpecial2(Special):
  def __repr__(self):
    return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y)

  def tree_flatten(self):
    children = (self.x, self.y)
    aux_data = None
    return (children, aux_data)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)

def show_example(structured):
  flat, tree = structured.tree_flatten()
  unflattened = RegisteredSpecial2.tree_unflatten(tree, flat)
  print(f"{structured=}\n  {flat=}\n  {tree=}\n  {unflattened=}")

show_example(RegisteredSpecial2(1., 2.))
structured=RegisteredSpecial2(x=1.0, y=2.0)
  flat=(1.0, 2.0)
  tree=None
  unflattened=RegisteredSpecial2(x=1.0, y=2.0)

Modern Python comes equipped with helpful tools to make defining containers easier. Some will work with JAX out-of-the-box, but others require more care.

For instance, a Python NamedTuple subclass doesn’t need to be registered to be considered a pytree node type:

from typing import NamedTuple, Any

class MyOtherContainer(NamedTuple):
  name: str
  a: Any
  b: Any
  c: Any

# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box.
jax.tree.leaves([
    MyOtherContainer('Alice', 1, 2, 3),
    MyOtherContainer('Bob', 4, 5, 6)
])
['Alice', 1, 2, 3, 'Bob', 4, 5, 6]

Notice that the name field now appears as a leaf, because all tuple elements are children. This is the trade-off of not registering the class by hand: you cannot choose which fields count as children.

When defining flattening and unflattening functions, children should in general contain all the dynamic elements of the data structure (arrays, dynamic scalars, and pytrees), while aux_data should contain all the static elements that will be rolled into the treedef structure. JAX sometimes needs to compare treedefs for equality, or compute their hash for use in the JIT cache, so take care that the auxiliary data specified in the flattening recipe supports meaningful hashing and equality comparisons.
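As a rough illustration of this split, the sketch below registers a hypothetical Labeled container (an assumption for this example, not part of JAX) whose array goes into children and whose string label goes into aux_data. Because the label is a plain string, it hashes and compares cleanly, so two values with the same label should produce equal treedefs while a different label produces a different one.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Labeled:
  """Hypothetical container: one dynamic array plus one static label."""

  def __init__(self, label, value):
    self.label = label
    self.value = value

  def tree_flatten(self):
    children = (self.value,)  # dynamic: visible to JAX as a leaf
    aux_data = self.label     # static: must be hashable and comparable
    return (children, aux_data)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(aux_data, *children)


a = Labeled('weights', jnp.ones(3))
b = Labeled('weights', jnp.zeros(3))
c = Labeled('bias', jnp.zeros(3))

# Treedefs compare equal only when the static aux_data matches.
same_structure = jax.tree.structure(a) == jax.tree.structure(b)
different_structure = jax.tree.structure(a) == jax.tree.structure(c)
leaves = jax.tree.leaves(a)
print(same_structure, different_structure, len(leaves))
```

If aux_data were something without value-based equality, such as a freshly created array or an object with default identity comparison, structurally identical inputs could look different to the JIT cache.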

Unlike NamedTuple subclasses, classes decorated with @dataclass are not automatically pytrees. However, they can be registered as pytrees using the jax.tree_util.register_dataclass() decorator:

from dataclasses import dataclass
import jax.numpy as jnp
import numpy as np
import functools

@functools.partial(jax.tree_util.register_dataclass,
                   data_fields=['a', 'b', 'c'],
                   meta_fields=['name'])
@dataclass
class MyDataclassContainer(object):
  name: str
  a: Any
  b: Any
  c: Any

# MyDataclassContainer is now a pytree node.
jax.tree.leaves([
  MyDataclassContainer('apple', 5.3, 1.2, jnp.zeros([4])),
  MyDataclassContainer('banana', np.array([3, 4]), -1., 0.)
])
[5.3, 1.2, Array([0., 0., 0., 0.], dtype=float32), array([3, 4]), -1.0, 0.0]

Notice that the name field does not appear as a leaf. This is because we included it in the meta_fields argument to jax.tree_util.register_dataclass(), indicating that it should be treated as metadata/auxiliary data, just like aux_data in RegisteredSpecial above. Now instances of MyDataclassContainer can be passed into JIT-ed functions, and name will be treated as static (see Marking arguments as static for more information on static args):

@jax.jit
def f(x: MyDataclassContainer | MyOtherContainer):
  return x.a + x.b

# Works fine! `mdc.name` is static.
mdc = MyDataclassContainer('mdc', 1, 2, 3)
y = f(mdc)

Contrast this with MyOtherContainer, the NamedTuple subclass. Since the name field is a pytree leaf, JIT expects it to be convertible to jax.Array, and the following raises an error:

moc = MyOtherContainer('moc', 1, 2, 3)
y = f(moc)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 2
      1 moc = MyOtherContainer('moc', 1, 2, 3)
----> 2 y = f(moc)

    [... skipping hidden 5 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/pjit.py:678, in _infer_input_type(fun, dbg_fn, explicit_args)
    676   dbg = dbg_fn()
    677   arg_description = f"path {dbg.arg_names[i] if dbg.arg_names is not None else 'unknown'}"  # pytype: disable=name-error
--> 678   raise TypeError(
    679     f"Error interpreting argument to {fun} as an abstract array."
    680     f" The problematic value is of type {type(x)} and was passed to"  # pytype: disable=name-error
    681     f" the function at {arg_description}.\n"
    682     "This typically means that a jit-wrapped function was called with a non-array"
    683     " argument, and this argument was not marked as static using the"
    684     " static_argnums or static_argnames parameters of jax.jit."
    685   ) from None
    686 if config.mutable_array_checks.value:
    687   check_no_aliased_ref_args(dbg_fn, avals, explicit_args)

TypeError: Error interpreting argument to <function f at 0x71ad3db47880> as an abstract array. The problematic value is of type <class 'str'> and was passed to the function at path x.name.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

The full set of functions for operating on pytrees is in jax.tree_util.

Custom pytrees and initialization with unexpected values

Another common gotcha with user-defined pytree objects is that JAX transformations occasionally initialize them with unexpected values, so that any input validation done at initialization may fail. For example:

class MyTree:
  def __init__(self, a):
    self.a = jnp.asarray(a)

register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
    lambda _, args: MyTree(*args))

tree = MyTree(jnp.arange(5.0))

jax.vmap(lambda x: x)(tree)      # Error because object() is passed to `MyTree`.
<__main__.MyTree at 0x71ad4021c920>
jax.jacobian(lambda x: x)(tree)  # Error because MyTree(...) is passed to `MyTree`.
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[11], line 1
----> 1 jax.jacobian(lambda x: x)(tree)  # Error because MyTree(...) is passed to `MyTree`.

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/api.py:836, in jacrev.<locals>.jacfun(*args, **kwargs)
    829 f = lu.wrap_init(
    830     fun, kwargs,
    831     debug_info=debug_info(
    832         "jacrev", fun, args, kwargs,
    833         static_argnums=(argnums,) if isinstance(argnums, int) else argnums))
    834 f_partial, dyn_args = argnums_partial(f, argnums, args,
    835                                       require_static_args_hashable=False)
--> 836 tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
    837 if not has_aux:
    838   y, pullback = _vjp(f_partial, *dyn_args)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/tree_util.py:361, in tree_map(f, tree, is_leaf, *rest)
    359 leaves, treedef = tree_flatten(tree, is_leaf)
    360 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 361 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

Cell In[10], line 6, in <lambda>(_, args)
      2   def __init__(self, a):
      3     self.a = jnp.asarray(a)
      5 register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
----> 6     lambda _, args: MyTree(*args))
      8 tree = MyTree(jnp.arange(5.0))
     10 jax.vmap(lambda x: x)(tree)      # Error because object() is passed to `MyTree`.

Cell In[10], line 3, in MyTree.__init__(self, a)
      2 def __init__(self, a):
----> 3   self.a = jnp.asarray(a)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/numpy/array_constructors.py:401, in asarray(a, dtype, order, copy, device, out_sharding)
    399 if dtype is not None:
    400   dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "asarray")
--> 401 return array(a, dtype=dtype, copy=bool(copy), order=order, device=device,
    402              out_sharding=out_sharding)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/numpy/array_constructors.py:233, in array(object, dtype, copy, order, ndmin, device, out_sharding)
    230 leaves, treedef = tree_util.tree_flatten(
    231     object, is_leaf=lambda x: not isinstance(x, (list, tuple)))
    232 if any(leaf is None for leaf in leaves):
--> 233   raise ValueError("None is not a valid value for jnp.array")
    234 leaves = [
    235     leaf
    236     if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None
    237     else leaf_jax_array()
    238     for leaf in leaves
    239 ]
    240 if dtype is None:
    241   # Use lattice_result_type rather than result_type to avoid canonicalization.
    242   # Otherwise, weakly-typed inputs would have their dtypes canonicalized.

ValueError: None is not a valid value for jnp.array

  • In the first case with jax.vmap(...)(tree), JAX’s internals use arrays of object() values to infer the structure of the tree.

  • In the second case with jax.jacobian(...)(tree), the Jacobian of a function mapping a tree to a tree is defined as a tree of trees.
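The "tree of trees" shape of a Jacobian is easier to see with a built-in container. The sketch below (the variable names are illustrative only) takes the Jacobian of an identity function on a dict pytree; the outer dict follows the output structure and the inner dict follows the input structure. With a custom class, the same nesting means the unflattening function can receive another instance of the class as a child.

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.arange(3.0)}

# The identity function maps a dict pytree to a dict pytree, so the
# Jacobian is a dict (output structure) of dicts (input structure).
jac = jax.jacobian(lambda p: p)(params)

outer_keys = list(jac.keys())
inner_keys = list(jac['w'].keys())
block_shape = jac['w']['w'].shape
print(outer_keys, inner_keys, block_shape)
```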

Potential solution 1:

  • The __init__ and __new__ methods of custom pytree classes should generally avoid doing any array conversion or other input validation, or else anticipate and handle these special cases. For example:

class MyTree:
  def __init__(self, a):
    if not (type(a) is object or a is None or isinstance(a, MyTree)):
      a = jnp.asarray(a)
    self.a = a

Potential solution 2:

  • Structure your custom tree_unflatten function so that it avoids calling __init__. If you choose this route, make sure that your tree_unflatten function stays in sync with __init__ if and when the code is updated. Example:

def tree_unflatten(aux_data, children):
  del aux_data  # Unused in this class.
  obj = object.__new__(MyTree)
  obj.a = children[0]
  return obj
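To show how the second approach fits together end to end, here is a minimal sketch that registers a hypothetical SafeTree class (a renamed copy of MyTree, introduced only for this example) with an unflattening function that bypasses __init__. Under that assumption, both transformations that were problematic above should run, and the Jacobian comes back as a SafeTree nested inside a SafeTree.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


class SafeTree:
  """Hypothetical variant of MyTree whose unflattening skips __init__."""

  def __init__(self, a):
    # Conversion only runs for instances that user code constructs.
    self.a = jnp.asarray(a)


def safe_flatten(tree):
  return ((tree.a,), None)


def safe_unflatten(aux_data, children):
  del aux_data  # Unused in this class.
  obj = object.__new__(SafeTree)
  obj.a = children[0]  # Stored as-is: an array, a placeholder, or a SafeTree.
  return obj


register_pytree_node(SafeTree, safe_flatten, safe_unflatten)

tree = SafeTree(jnp.arange(5.0))
mapped = jax.vmap(lambda x: x)(tree)
jac = jax.jacobian(lambda x: x)(tree)

print(type(mapped).__name__, mapped.a.shape)
print(type(jac.a).__name__, jac.a.a.shape)
```

The cost of this design is the one noted above: any attribute that __init__ sets must also be set in the unflattening function, or reconstructed objects will be missing it.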

Internal pytree handling

JAX flattens pytrees into lists of leaves at the api.py boundary (and also in control flow primitives). This keeps downstream JAX internals simpler: transformations like grad(), jit(), and vmap() can handle user functions that accept and return the myriad different Python containers, while all the other parts of the system can operate on functions that only take (multiple) array arguments and always return a flat list of arrays.

When JAX flattens a pytree it will produce a list of leaves and a treedef object that encodes the structure of the original value. The treedef can then be used to construct a matching structured value after transforming the leaves. Pytrees are tree-like, rather than DAG-like or graph-like, in that we handle them assuming referential transparency and that they can’t contain reference cycles.

Here is a simple example:

from jax.tree_util import tree_flatten, tree_unflatten
import jax.numpy as jnp

# The structured value to be transformed
value_structured = [1., (2., 3.)]

# The leaves in value_flat correspond to the `*` markers in value_tree
value_flat, value_tree = tree_flatten(value_structured)
print(f"{value_flat=}\n{value_tree=}")

# Transform the flat value list using an element-wise numeric transformer
transformed_flat = list(map(lambda v: v * 2., value_flat))
print(f"{transformed_flat=}")

# Reconstruct the structured output, using the original
transformed_structured = tree_unflatten(value_tree, transformed_flat)
print(f"{transformed_structured=}")
value_flat=[1.0, 2.0, 3.0]
value_tree=PyTreeDef([*, (*, *)])
transformed_flat=[2.0, 4.0, 6.0]
transformed_structured=[2.0, (4.0, 6.0)]
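The same flatten, transform, unflatten pattern can be wrapped around a function, which is roughly the division of labor described above: the wrapper deals with Python containers and the wrapped function sees only a flat list. The helper below is a simplified sketch for illustration; call_on_leaves and double_all are hypothetical names, and this is not how JAX's internals are actually written.

```python
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten


def call_on_leaves(flat_fn, structured):
  """Hypothetical helper: run a flat-list function on a pytree's leaves."""
  leaves, treedef = tree_flatten(structured)
  new_leaves = flat_fn(leaves)  # flat_fn never sees the container structure
  return tree_unflatten(treedef, new_leaves)


def double_all(leaves):
  # A "flat" function: takes and returns a plain list of values.
  return [leaf * 2. for leaf in leaves]


result = call_on_leaves(double_all, {'a': jnp.ones(2), 'b': (1., 2.)})
print(result)
```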

By default, pytree containers can be lists, tuples, dicts, namedtuple, None, and OrderedDict. Other types of values, including numeric and ndarray values, are treated as leaves:

from collections import namedtuple
Point = namedtuple('Point', ['x', 'y'])

example_containers = [
    (1., [2., 3.]),
    (1., {'b': 2., 'a': 3.}),
    1.,
    None,
    jnp.zeros(2),
    Point(1., 2.)
]

def show_example(structured):
  flat, tree = tree_flatten(structured)
  unflattened = tree_unflatten(tree, flat)
  print(f"{structured=}\n  {flat=}\n  {tree=}\n  {unflattened=}")

for structured in example_containers:
  show_example(structured)
structured=(1.0, [2.0, 3.0])
  flat=[1.0, 2.0, 3.0]
  tree=PyTreeDef((*, [*, *]))
  unflattened=(1.0, [2.0, 3.0])
structured=(1.0, {'b': 2.0, 'a': 3.0})
  flat=[1.0, 3.0, 2.0]
  tree=PyTreeDef((*, {'a': *, 'b': *}))
  unflattened=(1.0, {'a': 3.0, 'b': 2.0})
structured=1.0
  flat=[1.0]
  tree=PyTreeDef(*)
  unflattened=1.0
structured=None
  flat=[]
  tree=PyTreeDef(None)
  unflattened=None
structured=Array([0., 0.], dtype=float32)
  flat=[Array([0., 0.], dtype=float32)]
  tree=PyTreeDef(*)
  unflattened=Array([0., 0.], dtype=float32)
structured=Point(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(namedtuple[Point], [*, *]))
  unflattened=Point(x=1.0, y=2.0)
