OpTree: Optimized PyTree Utilities


Optimized PyTree Utilities.




Installation

Install from PyPI:

pip3 install --upgrade optree

Install from conda-forge:

conda install conda-forge::optree

Install the latest version from GitHub:

pip3 install git+https://github.com/metaopt/optree.git#egg=optree

Or, clone this repo and install manually:

git clone --depth=1 https://github.com/metaopt/optree.git
cd optree
pip3 install .

The following options are available while building the Python C extension from the source:

export CMAKE_COMMAND="/path/to/custom/cmake"
export CMAKE_BUILD_TYPE="Debug"
export CMAKE_CXX_STANDARD="20"  # C++17 is tested on Linux/macOS (C++20 is required on Windows)
export OPTREE_CXX_WERROR="OFF"
export _GLIBCXX_USE_CXX11_ABI="1"
export pybind11_DIR="/path/to/custom/pybind11"
pip3 install .

Compiling from the source requires Python 3.9+, a compiler (gcc / clang / icc / cl.exe) that supports C++20, and a cmake installation.


PyTrees

A PyTree is a recursive structure that can be an arbitrarily nested Python container (e.g., tuple, list, dict, OrderedDict, NamedTuple, etc.) or an opaque Python object. The key concepts of tree operations are tree flattening and its inverse (tree unflattening). Additional tree operations can be performed based on these two basic functions (e.g., tree_map = tree_unflatten ∘ map ∘ tree_flatten).

Tree flattening is traversing the entire tree in a left-to-right depth-first manner and returning the leaves of the tree in a deterministic order.

>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': 5, 'd': 6}
>>> optree.tree_flatten(tree)
([1, 2, 3, 4, 5, 6], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}))
>>> optree.tree_flatten(1)
([1], PyTreeSpec(*))
>>> optree.tree_flatten(None)
([], PyTreeSpec(None))
>>> optree.tree_map(lambda x: x**2, tree)
{'b': (4, [9, 16]), 'a': 1, 'c': 25, 'd': 36}

This usually implies that equal pytrees return equal lists of leaves and the same tree structure. See also section Key Ordering for Dictionaries.

>>> {'a': [1, 2], 'b': [3]} == {'b': [3], 'a': [1, 2]}
True
>>> optree.tree_leaves({'a': [1, 2], 'b': [3]}) == optree.tree_leaves({'b': [3], 'a': [1, 2]})
True
>>> optree.tree_structure({'a': [1, 2], 'b': [3]}) == optree.tree_structure({'b': [3], 'a': [1, 2]})
True
>>> optree.tree_map(lambda x: x**2, {'a': [1, 2], 'b': [3]})
{'a': [1, 4], 'b': [9]}
>>> optree.tree_map(lambda x: x**2, {'b': [3], 'a': [1, 2]})
{'b': [9], 'a': [1, 4]}
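The identity tree_map = tree_unflatten ∘ map ∘ tree_flatten can be illustrated with a small pure-Python model. This is only a sketch and not OpTree's implementation: the names toy_flatten, toy_unflatten, and toy_map are hypothetical, only tuple, list, and dict are treated as containers, and the tree specification is modeled as a nested template with a placeholder at each leaf position.

```python
# A toy model of flatten/unflatten; NOT OpTree's implementation.
# `toy_flatten`, `toy_unflatten`, and `toy_map` are hypothetical names.

LEAF = object()  # sentinel marking a leaf position in the toy spec


def toy_flatten(tree):
    """Return (leaves, spec) using a left-to-right depth-first traversal."""
    leaves = []

    def visit(node):
        if type(node) in (tuple, list):
            return type(node)(visit(child) for child in node)
        if type(node) is dict:
            # Visit values in sorted key order so equal dicts flatten equally.
            return {key: visit(node[key]) for key in sorted(node)}
        leaves.append(node)
        return LEAF

    spec = visit(tree)
    return leaves, spec


def toy_unflatten(spec, leaves):
    """Rebuild a tree from a spec and an iterable of leaves."""
    it = iter(leaves)

    def build(node):
        if node is LEAF:
            return next(it)
        if type(node) in (tuple, list):
            return type(node)(build(child) for child in node)
        return {key: build(child) for key, child in node.items()}

    return build(spec)


def toy_map(func, tree):
    """tree_map expressed as unflatten . map . flatten."""
    leaves, spec = toy_flatten(tree)
    return toy_unflatten(spec, map(func, leaves))


tree = {'b': (2, [3, 4]), 'a': 1, 'c': 5, 'd': 6}
leaves, spec = toy_flatten(tree)
print(leaves)                         # [1, 2, 3, 4, 5, 6]
print(toy_map(lambda x: x**2, tree))  # {'a': 1, 'b': (4, [9, 16]), 'c': 25, 'd': 36}
```

Unlike OpTree, this toy rebuilds dictionaries in sorted key order; the result is still equal to the expected dictionary because dict equality ignores key order.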

Tip

Since OpTree v0.14.1, the new namespace optree.pytree provides aliases for the optree.tree_* functions. The following examples are equivalent to the above:

>>> import optree.pytree as pt
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': 5, 'd': 6}
>>> pt.flatten(tree)
([1, 2, 3, 4, 5, 6], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}))
>>> pt.flatten(1)
([1], PyTreeSpec(*))
>>> pt.flatten(None)
([], PyTreeSpec(None))
>>> pt.map(lambda x: x**2, tree)
{'b': (4, [9, 16]), 'a': 1, 'c': 25, 'd': 36}
>>> pt.map(lambda x: x**2, {'a': [1, 2], 'b': [3]})
{'a': [1, 4], 'b': [9]}
>>> pt.map(lambda x: x**2, {'b': [3], 'a': [1, 2]})
{'b': [9], 'a': [1, 4]}

Tree Nodes and Leaves

A tree is a collection of non-leaf nodes and leaf nodes, where the leaf nodes have no children to flatten. optree.tree_flatten(...) flattens the tree and returns a list of leaf nodes, while the non-leaf nodes are stored in the tree specification.

Built-in PyTree Node Types

OpTree supports the following Python container types in the registry out of the box:

  • tuple
  • list
  • dict
  • collections.namedtuple and its subclasses
  • collections.OrderedDict
  • collections.defaultdict
  • collections.deque
  • PyStructSequence types created by the C API PyStructSequence_NewType

which are considered non-leaf nodes in the tree. Python objects whose type is not registered are treated as leaf nodes. The registry lookup uses the is operator to determine whether the type is matched, so subclasses need to be registered explicitly; otherwise, an object of that type is considered a leaf. The NoneType is a special case discussed in section None is Non-leaf Node vs. None is Leaf.
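The difference between an exact-type lookup and an isinstance check can be seen in a few lines of plain Python. The snippet below is a simplified model rather than OpTree's registry: REGISTERED_TYPES and is_registered are hypothetical names used only for illustration.

```python
from collections import OrderedDict

# Hypothetical stand-in for a set of registered container types.
REGISTERED_TYPES = (tuple, list, dict, OrderedDict)


def is_registered(obj):
    """Exact-type match, mirroring `type(obj) is registered_type`."""
    return any(type(obj) is t for t in REGISTERED_TYPES)


class MyList(list):  # a subclass that was never registered
    pass


print(isinstance(MyList([1, 2]), list))  # True
print(is_registered([1, 2]))             # True
print(is_registered(MyList([1, 2])))     # False: would be treated as a leaf
```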

Registering a Container-like Custom Type as Non-leaf Nodes

A container-like Python type can be registered in the type registry with a pair of functions that specify:

  • flatten_func(container) -> (children, metadata, entries): convert an instance of the container type to a (children, metadata, entries) triple, where children is an iterable of subtrees and entries is an iterable of path entries of the container (e.g., indices or keys).
  • unflatten_func(metadata, children) -> container: convert such a pair back to an instance of the container type.

The metadata is the data needed, apart from the children, to reconstruct the container, e.g., the keys of a dictionary (the children are the values).

The entries are optional: the flatten function may omit them (returning only a pair) or return None in their place. In either case, range(len(children)) (i.e., flat indices) is used as the path entries of the current node. The signature of the flatten function can be one of the following:

  • flatten_func(container) -> (children, metadata, entries)
  • flatten_func(container) -> (children, metadata, None)
  • flatten_func(container) -> (children, metadata)
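The three accepted return shapes can be reduced to one canonical triple. The helper below is a sketch of that rule only; normalize_flatten_result is a hypothetical name and is not part of the OpTree API.

```python
def normalize_flatten_result(result):
    """Return (children, metadata, entries) from a pair or a triple."""
    if len(result) == 2:
        children, metadata = result
        entries = None
    elif len(result) == 3:
        children, metadata, entries = result
    else:
        raise ValueError('flatten_func must return a pair or a triple')
    children = list(children)
    if entries is None:
        # Missing entries default to flat indices.
        entries = range(len(children))
    return children, metadata, list(entries)


# (children, metadata): entries default to range(len(children))
print(normalize_flatten_result((sorted({3, 1, 2}), None)))
# ([1, 2, 3], None, [0, 1, 2])

# (children, metadata, entries): explicit keys as path entries
d = {'x': 10, 'y': 20}
print(normalize_flatten_result((d.values(), list(d), list(d))))
# ([10, 20], ['x', 'y'], ['x', 'y'])
```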

The following examples show how to register custom types and utilize them for tree_flatten and tree_map. Please refer to section Notes about the PyTree Type Registry for more information.

# Register a Python type with lambda functions
optree.register_pytree_node(
    set,
    # (set) -> (children, metadata, None)
    lambda s: (sorted(s), None, None),
    # (metadata, children) -> (set)
    lambda _, children: set(children),
    namespace='set',
)

# Register a Python type into a namespace
import torch

class Torch2NumpyEntry(optree.PyTreeEntry):
    def __call__(self, obj):
        assert self.entry == 0
        return obj.cpu().detach().numpy()

    def codify(self, node=''):
        assert self.entry == 0
        return f'{node}.cpu().detach().numpy()'

optree.register_pytree_node(
    torch.Tensor,
    # (tensor) -> (children, metadata)
    flatten_func=lambda tensor: (
        (tensor.cpu().detach().numpy(),),
        {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
    ),
    # (metadata, children) -> tensor
    unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
    path_entry_type=Torch2NumpyEntry,
    namespace='torch2numpy',
)

>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
>>> tree
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}

# Flatten without specifying the namespace
>>> optree.tree_flatten(tree)  # `torch.Tensor`s are leaf nodes
([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))

# Flatten with the namespace
>>> leaves, treespec = optree.tree_flatten(tree, namespace='torch2numpy')
>>> leaves, treespec
(
    [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
    PyTreeSpec(
        {
            'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
            'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
        },
        namespace='torch2numpy'
    )
)

# `entries` are not defined and use `range(len(children))`
>>> optree.tree_paths(tree, namespace='torch2numpy')
[('bias', 0), ('weight', 0)]

# Custom path entry type defines the pytree access behavior
>>> optree.tree_accessors(tree, namespace='torch2numpy')
[
    PyTreeAccessor(*['bias'].cpu().detach().numpy(), (MappingEntry(key='bias', type=<class 'dict'>), Torch2NumpyEntry(entry=0, type=<class 'torch.Tensor'>))),
    PyTreeAccessor(*['weight'].cpu().detach().numpy(), (MappingEntry(key='weight', type=<class 'dict'>), Torch2NumpyEntry(entry=0, type=<class 'torch.Tensor'>)))
]

# Unflatten back to a copy of the original object
>>> optree.tree_unflatten(treespec, leaves)
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}

Users can also extend the pytree registry by decorating the custom class and defining an instance method tree_flatten and a class method tree_unflatten.

from collections import UserDict

@optree.register_pytree_node_class(namespace='mydict')
class MyDict(UserDict):
    TREE_PATH_ENTRY_TYPE = optree.MappingEntry  # used by accessor APIs

    def tree_flatten(self):  # -> (children, metadata, entries)
        reversed_keys = sorted(self.keys(), reverse=True)
        return (
            [self[key] for key in reversed_keys],  # children
            reversed_keys,  # metadata
            reversed_keys,  # entries
        )

    @classmethod
    def tree_unflatten(cls, metadata, children):
        return cls(zip(metadata, children))

>>> tree = MyDict(b=4, a=(2, 3), c=MyDict({'d': 5, 'f': 6}))

# Flatten without specifying the namespace
>>> optree.tree_flatten_with_path(tree)  # `MyDict`s are leaf nodes
(
    [()],
    [MyDict(b=4, a=(2, 3), c=MyDict({'d': 5, 'f': 6}))],
    PyTreeSpec(*)
)

# Flatten with the namespace
>>> optree.tree_flatten_with_path(tree, namespace='mydict')
(
    [('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)],
    [6, 5, 4, 2, 3],
    PyTreeSpec(
        CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyDict[['f', 'd']], [*, *]), *, (*, *)]),
        namespace='mydict'
    )
)
>>> optree.tree_flatten_with_accessor(tree, namespace='mydict')
(
    [
        PyTreeAccessor(*['c']['f'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='f', type=<class 'MyDict'>))),
        PyTreeAccessor(*['c']['d'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='d', type=<class 'MyDict'>))),
        PyTreeAccessor(*['b'], (MappingEntry(key='b', type=<class 'MyDict'>),)),
        PyTreeAccessor(*['a'][0], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
        PyTreeAccessor(*['a'][1], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=1, type=<class 'tuple'>)))
    ],
    [6, 5, 4, 2, 3],
    PyTreeSpec(
        CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyDict[['f', 'd']], [*, *]), *, (*, *)]),
        namespace='mydict'
    )
)

Notes about the PyTree Type Registry

There are several key attributes of the pytree type registry:

  1. The type registry is per-interpreter. This means registering a custom type in the registry affects all modules that use OpTree.

Warning

For safety reasons, a namespace must be specified while registering a custom type. It is used to isolate the behavior of flattening and unflattening a pytree node type. This prevents accidental collisions between different libraries that may register the same type.

  2. The elements in the type registry are immutable. Users can neither register the same type twice in the same namespace (i.e., update the type registry) nor remove a type from the type registry. To update the behavior of an already registered type, simply register it again with another namespace.

  3. Users cannot modify the behavior of already registered built-in types listed in Built-in PyTree Node Types, such as key order sorting for dict and collections.defaultdict.

  4. Inherited subclasses are not implicitly registered. The registry lookup uses type(obj) is registered_type rather than isinstance(obj, registered_type). Users need to register the subclasses explicitly. Registering all subclasses is easy to implement with a metaclass or __init_subclass__, for example:

    from collections import UserDict

    @optree.register_pytree_node_class(namespace='mydict')
    class MyDict(UserDict):
        TREE_PATH_ENTRY_TYPE = optree.MappingEntry  # used by accessor APIs

        def __init_subclass__(cls):  # define this in the base class
            super().__init_subclass__()
            # Register a subclass to namespace 'mydict'
            optree.register_pytree_node_class(cls, namespace='mydict')

        def tree_flatten(self):  # -> (children, metadata, entries)
            reversed_keys = sorted(self.keys(), reverse=True)
            return (
                [self[key] for key in reversed_keys],  # children
                reversed_keys,  # metadata
                reversed_keys,  # entries
            )

        @classmethod
        def tree_unflatten(cls, metadata, children):
            return cls(zip(metadata, children))

    # Subclasses will be automatically registered in namespace 'mydict'
    class MyAnotherDict(MyDict):
        pass

    >>> tree = MyDict(b=4, a=(2, 3), c=MyAnotherDict({'d': 5, 'f': 6}))
    >>> optree.tree_flatten_with_path(tree, namespace='mydict')
    (
        [('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)],
        [6, 5, 4, 2, 3],
        PyTreeSpec(
            CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyAnotherDict[['f', 'd']], [*, *]), *, (*, *)]),
            namespace='mydict'
        )
    )
    >>> optree.tree_accessors(tree, namespace='mydict')
    [
        PyTreeAccessor(*['c']['f'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='f', type=<class 'MyAnotherDict'>))),
        PyTreeAccessor(*['c']['d'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='d', type=<class 'MyAnotherDict'>))),
        PyTreeAccessor(*['b'], (MappingEntry(key='b', type=<class 'MyDict'>),)),
        PyTreeAccessor(*['a'][0], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
        PyTreeAccessor(*['a'][1], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=1, type=<class 'tuple'>)))
    ]

  5. Be careful about potential infinite recursion in the custom flatten function. The children returned from the custom flatten function are considered subtrees and are flattened further, recursively. The children can have the same type as the current node, so users must design their termination condition carefully.

    import numpy as np
    import torch

    optree.register_pytree_node(
        np.ndarray,
        # Children are nested lists of Python objects
        lambda array: (np.atleast_1d(array).tolist(), array.ndim == 0),
        lambda scalar, rows: np.asarray(rows) if not scalar else np.asarray(rows[0]),
        namespace='numpy1',
    )

    optree.register_pytree_node(
        np.ndarray,
        # Children are Python objects
        lambda array: (
            list(array.ravel()),  # list(1DArray[T]) -> List[T]
            dict(shape=array.shape, dtype=array.dtype)
        ),
        lambda metadata, children: np.asarray(children, dtype=metadata['dtype']).reshape(metadata['shape']),
        namespace='numpy2',
    )

    optree.register_pytree_node(
        np.ndarray,
        # Returns a list of `np.ndarray`s without termination condition
        lambda array: ([array.ravel()], array.shape),
        lambda shape, children: children[0].reshape(shape),
        namespace='numpy3',
    )

    optree.register_pytree_node(
        torch.Tensor,
        # Children are nested lists of Python objects
        lambda tensor: (torch.atleast_1d(tensor).tolist(), tensor.ndim == 0),
        lambda scalar, rows: torch.tensor(rows) if not scalar else torch.tensor(rows[0]),
        namespace='torch1',
    )

    optree.register_pytree_node(
        torch.Tensor,
        # Returns a list of `torch.Tensor`s without termination condition
        lambda tensor: (
            list(tensor.view(-1)),  # list(1DTensor[T]) -> List[0DTensor[T]] (STILL TENSORS!)
            tensor.shape
        ),
        lambda shape, children: torch.stack(children).reshape(shape),
        namespace='torch2',
    )

    >>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy1')
    (
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        PyTreeSpec(
            CustomTreeNode(ndarray[False], [[*, *, *], [*, *, *], [*, *, *]]),
            namespace='numpy1'
        )
    )
    # Implicitly casts `float`s to `np.float64`
    >>> optree.tree_map(lambda x: x + 1.5, np.arange(9).reshape(3, 3), namespace='numpy1')
    array([[1.5, 2.5, 3.5],
           [4.5, 5.5, 6.5],
           [7.5, 8.5, 9.5]])

    >>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy2')
    (
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        PyTreeSpec(
            CustomTreeNode(ndarray[{'shape': (3, 3), 'dtype': dtype('int64')}], [*, *, *, *, *, *, *, *, *]),
            namespace='numpy2'
        )
    )
    # Explicitly casts `float`s to `np.int64`
    >>> optree.tree_map(lambda x: x + 1.5, np.arange(9).reshape(3, 3), namespace='numpy2')
    array([[1, 2, 3],
           [4, 5, 6],
           [7, 8, 9]])

    # Children are also `np.ndarray`s, recurse without termination condition.
    >>> optree.tree_flatten(np.arange(9).reshape(3, 3), namespace='numpy3')
    Traceback (most recent call last):
        ...
    RecursionError: Maximum recursion depth exceeded during flattening the tree.

    >>> optree.tree_flatten(torch.arange(9).reshape(3, 3), namespace='torch1')
    (
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        PyTreeSpec(
            CustomTreeNode(Tensor[False], [[*, *, *], [*, *, *], [*, *, *]]),
            namespace='torch1'
        )
    )
    # Implicitly casts `float`s to `torch.float32`
    >>> optree.tree_map(lambda x: x + 1.5, torch.arange(9).reshape(3, 3), namespace='torch1')
    tensor([[1.5000, 2.5000, 3.5000],
            [4.5000, 5.5000, 6.5000],
            [7.5000, 8.5000, 9.5000]])

    # Children are also `torch.Tensor`s, recurse without termination condition.
    >>> optree.tree_flatten(torch.arange(9).reshape(3, 3), namespace='torch2')
    Traceback (most recent call last):
        ...
    RecursionError: Maximum recursion depth exceeded during flattening the tree.
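The registry properties listed above (a required namespace, immutable entries, and exact-type lookup) can be modeled with a dictionary keyed by (namespace, type). This is a minimal sketch under those assumptions and not OpTree's implementation; ToyRegistry and its methods are hypothetical names.

```python
class ToyRegistry:
    """Maps (namespace, type) -> (flatten_func, unflatten_func)."""

    def __init__(self):
        self._entries = {}

    def register(self, cls, flatten_func, unflatten_func, *, namespace):
        if not namespace:
            raise ValueError('a non-empty namespace is required')
        key = (namespace, cls)
        if key in self._entries:
            # Entries are immutable: no re-registration in the same namespace.
            raise ValueError(f'{cls.__name__} is already registered in {namespace!r}')
        self._entries[key] = (flatten_func, unflatten_func)

    def lookup(self, obj, namespace):
        # Exact type match; subclasses and other types are not found implicitly.
        return self._entries.get((namespace, type(obj)))


registry = ToyRegistry()
registry.register(set, lambda s: (sorted(s), None), lambda _, c: set(c), namespace='sorted-set')
# A different behavior for the same type lives in another namespace.
registry.register(set, lambda s: (list(s), None), lambda _, c: set(c), namespace='plain-set')

try:
    registry.register(set, lambda s: (sorted(s), None), lambda _, c: set(c), namespace='sorted-set')
except ValueError as exc:
    print(exc)  # set is already registered in 'sorted-set'

print(registry.lookup({1, 2}, 'sorted-set') is not None)         # True
print(registry.lookup({1, 2}, 'unknown') is None)                # True
print(registry.lookup(frozenset({1, 2}), 'sorted-set') is None)  # True: different type
```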

None is Non-leaf Node vs. None is Leaf

The None object is a special object in the Python language. It serves some of the same purposes as null (a pointer that does not point to anything) in other programming languages, denoting that a variable is empty or marking default parameters. However, the None object is a singleton object rather than a pointer. It may also serve as a sentinel value. In addition, if a function returns without a return value or omits the return statement, it implicitly returns the None object.

By default, the None object is considered a non-leaf node in the tree with arity 0, i.e., a non-leaf node that has no children. This is like the behavior of an empty tuple. While flattening a tree, it remains in the tree structure definition rather than in the leaves list.

>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> optree.tree_flatten(tree)
([1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}))
>>> optree.tree_flatten(tree, none_is_leaf=True)
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
>>> optree.tree_flatten(1)
([1], PyTreeSpec(*))
>>> optree.tree_flatten(None)
([], PyTreeSpec(None))
>>> optree.tree_flatten(None, none_is_leaf=True)
([None], PyTreeSpec(*, NoneIsLeaf))

OpTree provides a keyword argument none_is_leaf to determine whether to consider the None object as a leaf, like other opaque objects. If none_is_leaf=True, the None object is placed in the leaves list. Otherwise, the None object remains in the tree specification (structure).

>>> import torch
>>> linear = torch.nn.Linear(in_features=3, out_features=2, bias=False)
>>> linear._parameters  # a container has None
OrderedDict({
    'weight': Parameter containing:
              tensor([[-0.6677, 0.5209, 0.3295],
                      [-0.4876, -0.3142, 0.1785]], requires_grad=True),
    'bias': None
})
>>> optree.tree_map(torch.zeros_like, linear._parameters)
OrderedDict({
    'weight': tensor([[0., 0., 0.],
                      [0., 0., 0.]]),
    'bias': None
})
>>> optree.tree_map(torch.zeros_like, linear._parameters, none_is_leaf=True)
Traceback (most recent call last):
    ...
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not NoneType
>>> optree.tree_map(lambda t: torch.zeros_like(t) if t is not None else 0, linear._parameters, none_is_leaf=True)
OrderedDict({
    'weight': tensor([[0., 0., 0.],
                      [0., 0., 0.]]),
    'bias': 0
})
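The effect of none_is_leaf can be mimicked without OpTree by a short recursive function. The sketch below is a simplified model (toy_leaves is a hypothetical name and only tuple, list, and dict are treated as containers), but it reproduces the leaves from the first example in this section.

```python
def toy_leaves(tree, none_is_leaf=False):
    """Collect leaves depth-first; `None` is a leaf only if requested."""
    if tree is None and not none_is_leaf:
        return []  # a node with no children: contributes no leaves
    if type(tree) in (tuple, list):
        return [leaf for child in tree for leaf in toy_leaves(child, none_is_leaf)]
    if type(tree) is dict:
        return [leaf for key in sorted(tree) for leaf in toy_leaves(tree[key], none_is_leaf)]
    return [tree]


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
print(toy_leaves(tree))                     # [1, 2, 3, 4, 5]
print(toy_leaves(tree, none_is_leaf=True))  # [1, 2, 3, 4, None, 5]
print(toy_leaves(None))                     # []
print(toy_leaves(None, none_is_leaf=True))  # [None]
```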

Key Ordering for Dictionaries

The built-in Python dictionary (i.e., builtins.dict) is an unordered mapping that holds the keys and values. The leaves of a dictionary are the values. Although the built-in dictionary has been insertion ordered since Python 3.6 (PEP 468), the dictionary equality operator (==) does not check for key ordering. To ensure referential transparency, i.e., that "equal dict" implies "equal ordering of leaves", the values of the dictionary are ordered by their sorted keys. This behavior also applies to collections.defaultdict.

>>> optree.tree_flatten({'a': [1, 2], 'b': [3]})
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))
>>> optree.tree_flatten({'b': [3], 'a': [1, 2]})
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))

If users want to keep the values in insertion order during pytree traversal, they should use collections.OrderedDict, which takes the order of keys into consideration:

>>> OrderedDict([('a', [1, 2]), ('b', [3])]) == OrderedDict([('b', [3]), ('a', [1, 2])])
False
>>> optree.tree_flatten(OrderedDict([('a', [1, 2]), ('b', [3])]))
([1, 2, 3], PyTreeSpec(OrderedDict({'a': [*, *], 'b': [*]})))
>>> optree.tree_flatten(OrderedDict([('b', [3]), ('a', [1, 2])]))
([3, 1, 2], PyTreeSpec(OrderedDict({'b': [*], 'a': [*, *]})))

Since OpTree v0.9.0, the key order of the reconstructed output dictionaries from tree_unflatten is guaranteed to be consistent with the key order of the input dictionaries in tree_flatten.

>>> leaves, treespec = optree.tree_flatten({'b': [3], 'a': [1, 2]})
>>> leaves, treespec
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))
>>> optree.tree_unflatten(treespec, leaves)
{'b': [3], 'a': [1, 2]}
>>> optree.tree_map(lambda x: x, {'b': [3], 'a': [1, 2]})
{'b': [3], 'a': [1, 2]}
>>> optree.tree_map(lambda x: x + 1, {'b': [3], 'a': [1, 2]})
{'b': [4], 'a': [2, 3]}

This property is also preserved during serialization/deserialization.

>>> leaves, treespec = optree.tree_flatten({'b': [3], 'a': [1, 2]})
>>> leaves, treespec
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))
>>> restored_treespec = pickle.loads(pickle.dumps(treespec))
>>> optree.tree_unflatten(treespec, leaves)
{'b': [3], 'a': [1, 2]}
>>> optree.tree_unflatten(restored_treespec, leaves)
{'b': [3], 'a': [1, 2]}

Note

Note that there are no restrictions on the dict that require the keys to be comparable (sortable). There can be multiple types of keys in the dictionary. The keys are sorted in ascending order by key=lambda k: k first if possible, otherwise falling back to key=lambda k: (f'{k.__class__.__module__}.{k.__class__.__qualname__}', k). This handles most cases.

>>> sorted({1: 2, 1.5: 1}.keys())
[1, 1.5]
>>> sorted({'a': 3, 1: 2, 1.5: 1}.keys())
Traceback (most recent call last):
    ...
TypeError: '<' not supported between instances of 'int' and 'str'
>>> sorted({'a': 3, 1: 2, 1.5: 1}.keys(), key=lambda k: (f'{k.__class__.__module__}.{k.__class__.__qualname__}', k))
[1.5, 1, 'a']
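The two-step ordering described in the note can be wrapped in one helper. This is a sketch of the rule as stated above, not OpTree's internal code; sorted_keys is a hypothetical name.

```python
def sorted_keys(mapping):
    """Sort keys directly if possible, else by (qualified type name, key)."""
    try:
        return sorted(mapping)
    except TypeError:
        # Keys of different types are not mutually comparable:
        # group them by their fully qualified type name first.
        return sorted(
            mapping,
            key=lambda k: (f'{k.__class__.__module__}.{k.__class__.__qualname__}', k),
        )


print(sorted_keys({1: 2, 1.5: 1}))          # [1, 1.5]
print(sorted_keys({'a': 3, 1: 2, 1.5: 1}))  # [1.5, 1, 'a']
```

In the second call the fallback orders the keys by 'builtins.float', 'builtins.int', and 'builtins.str', which is why 1.5 comes before 1.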

Benchmark

We benchmark the performance of:

  • tree flatten
  • tree unflatten
  • tree copy (i.e., unflatten(flatten(...)))
  • tree map

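compared with the following libraries (OpTree v0.9.0 is the baseline):

  • JAX XLA (v0.4.6)
  • PyTorch (v2.0.0)
  • DM-Tree (v0.1.8)

| Average Time Cost (↓) | OpTree (v0.9.0) | JAX XLA (v0.4.6) | PyTorch (v2.0.0) | DM-Tree (v0.1.8) |
|---|---|---|---|---|
| Tree Flatten | x1.00 | 2.33 | 22.05 | 1.12 |
| Tree UnFlatten | x1.00 | 2.69 | 4.28 | 16.23 |
| Tree Flatten with Path | x1.00 | 16.16 | Not Supported | 27.59 |
| Tree Copy | x1.00 | 2.56 | 9.97 | 11.02 |
| Tree Map | x1.00 | 2.56 | 9.58 | 10.62 |
| Tree Map (nargs) | x1.00 | 2.89 | Not Supported | 31.33 |
| Tree Map with Path | x1.00 | 7.23 | Not Supported | 19.66 |
| Tree Map with Path (nargs) | x1.00 | 6.56 | Not Supported | 29.61 |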

All results are reported on a workstation with an AMD Ryzen 9 5950X CPU @ 4.45GHz in an isolated virtual environment with Python 3.10.9. Run with the following commands:

conda create --name optree-benchmark anaconda::python=3.10 --yes --no-default-packages
conda activate optree-benchmark
python3 -m pip install --editable '.[benchmark]' --extra-index-url https://download.pytorch.org/whl/cpu
python3 benchmark.py --number=10000 --repeat=5

The test inputs are nested containers (i.e., pytrees) extracted from torch.nn.Module objects. They are:

tiny_mlp = nn.Sequential(
    nn.Linear(1, 1, bias=True),
    nn.BatchNorm1d(1, affine=True, track_running_stats=True),
    nn.ReLU(),
    nn.Linear(1, 1, bias=False),
    nn.Sigmoid(),
)

and AlexNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, VisionTransformerH14 (ViT-H/14), and SwinTransformerB (Swin-B) from torchvision. Please refer to benchmark.py for more details.
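For readers unfamiliar with number/repeat style timing, the following sketch shows the general idea with the standard timeit module. How benchmark.py actually aggregates its measurements is not described here, so treating --number and --repeat like the timeit parameters and taking the minimum over repeats are assumptions; best_time_per_call_us and count_leaves are hypothetical names, and the workload is a toy stand-in rather than one of the benchmarked operations.

```python
import timeit


def best_time_per_call_us(func, number, repeat):
    """Return the best average time per call (in μs) over `repeat` runs."""
    # Each entry of `totals` is the total time of `number` consecutive calls.
    totals = timeit.repeat(func, number=number, repeat=repeat)
    return min(totals) / number * 1e6


def count_leaves(tree):
    """Toy workload: count the leaves of a nested dict/list structure."""
    if isinstance(tree, dict):
        return sum(count_leaves(value) for value in tree.values())
    if isinstance(tree, (list, tuple)):
        return sum(count_leaves(item) for item in tree)
    return 1


# A small stand-in for the parameter containers of a module
tree = {f'layer{i}': {'weight': [0.0] * 4, 'bias': [0.0]} for i in range(20)}

elapsed = best_time_per_call_us(lambda: count_leaves(tree), number=200, repeat=5)
print(f'{elapsed:.2f} μs per call')
```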

Tree Flatten

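| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
|---|---|---|---|---|---|---|---|---|
| TinyMLP | 53 | 29.70 | 71.06 | 583.66 | 31.32 | 2.39 | 19.65 | 1.05 |
| AlexNet | 188 | 103.92 | 262.56 | 2304.36 | 119.61 | 2.53 | 22.17 | 1.15 |
| ResNet18 | 698 | 368.06 | 852.69 | 8440.31 | 420.43 | 2.32 | 22.93 | 1.14 |
| ResNet34 | 1242 | 644.96 | 1461.55 | 14498.81 | 712.81 | 2.27 | 22.48 | 1.11 |
| ResNet50 | 1702 | 919.95 | 2080.58 | 20995.96 | 1006.42 | 2.26 | 22.82 | 1.09 |
| ResNet101 | 3317 | 1806.36 | 3996.90 | 40314.12 | 1955.48 | 2.21 | 22.32 | 1.08 |
| ResNet152 | 4932 | 2656.92 | 5812.38 | 57775.53 | 2826.92 | 2.19 | 21.75 | 1.06 |
| ViT-H/14 | 3420 | 1863.50 | 4418.24 | 41334.64 | 2128.71 | 2.37 | 22.18 | 1.14 |
| Swin-B | 2881 | 1631.06 | 3944.13 | 36131.54 | 2032.77 | 2.42 | 22.15 | 1.25 |
| Average | | | | | | 2.33 | 22.05 | 1.12 |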
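The speedup columns are ratios of the other library's time to OpTree's time (e.g., J / O is JAX XLA over OpTree), and the Average row is consistent with the mean of the per-module ratios. The short check below recomputes the J / O column from the Tree Flatten timings above; that the average is an unweighted arithmetic mean of the ratios is an assumption that happens to match the reported value.

```python
from statistics import mean

# (module, OpTree μs, JAX XLA μs) copied from the Tree Flatten table
rows = [
    ('TinyMLP', 29.70, 71.06),
    ('AlexNet', 103.92, 262.56),
    ('ResNet18', 368.06, 852.69),
    ('ResNet34', 644.96, 1461.55),
    ('ResNet50', 919.95, 2080.58),
    ('ResNet101', 1806.36, 3996.90),
    ('ResNet152', 2656.92, 5812.38),
    ('ViT-H/14', 1863.50, 4418.24),
    ('Swin-B', 1631.06, 3944.13),
]

# Speedup (J / O): time of JAX XLA divided by time of OpTree
speedups = {name: jax_us / optree_us for name, optree_us, jax_us in rows}
for name, ratio in speedups.items():
    print(f'{name}: {ratio:.2f}')

average = mean(speedups.values())
print(f'Average: {average:.2f}')  # Average: 2.33
```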

Tree UnFlatten

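| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
|---|---|---|---|---|---|---|---|---|
| TinyMLP | 53 | 55.13 | 152.07 | 231.94 | 940.11 | 2.76 | 4.21 | 17.05 |
| AlexNet | 188 | 226.29 | 678.29 | 972.90 | 4195.04 | 3.00 | 4.30 | 18.54 |
| ResNet18 | 698 | 766.54 | 1953.26 | 3137.86 | 12049.88 | 2.55 | 4.09 | 15.72 |
| ResNet34 | 1242 | 1309.22 | 3526.12 | 5759.16 | 20966.75 | 2.69 | 4.40 | 16.01 |
| ResNet50 | 1702 | 1914.96 | 5002.83 | 8369.43 | 29597.10 | 2.61 | 4.37 | 15.46 |
| ResNet101 | 3317 | 3672.61 | 9633.29 | 15683.16 | 57240.20 | 2.62 | 4.27 | 15.59 |
| ResNet152 | 4932 | 5407.58 | 13970.88 | 23074.68 | 82072.54 | 2.58 | 4.27 | 15.18 |
| ViT-H/14 | 3420 | 4013.18 | 11146.31 | 17633.07 | 66723.58 | 2.78 | 4.39 | 16.63 |
| Swin-B | 2881 | 3595.34 | 9505.31 | 15054.88 | 57310.03 | 2.64 | 4.19 | 15.94 |
| Average | | | | | | 2.69 | 4.28 | 16.23 |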

Tree Flatten with Path

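| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
|---|---|---|---|---|---|---|---|---|
| TinyMLP | 53 | 36.49 | 543.67 | N/A | 919.13 | 14.90 | N/A | 25.19 |
| AlexNet | 188 | 115.44 | 2185.21 | N/A | 3752.11 | 18.93 | N/A | 32.50 |
| ResNet18 | 698 | 431.84 | 7106.55 | N/A | 12286.70 | 16.46 | N/A | 28.45 |
| ResNet34 | 1242 | 845.61 | 13431.99 | N/A | 22860.48 | 15.88 | N/A | 27.03 |
| ResNet50 | 1702 | 1166.27 | 18426.52 | N/A | 31225.05 | 15.80 | N/A | 26.77 |
| ResNet101 | 3317 | 2312.77 | 34770.49 | N/A | 59346.86 | 15.03 | N/A | 25.66 |
| ResNet152 | 4932 | 3304.74 | 50557.25 | N/A | 85847.91 | 15.30 | N/A | 25.98 |
| ViT-H/14 | 3420 | 2235.25 | 37473.53 | N/A | 64105.24 | 16.76 | N/A | 28.68 |
| Swin-B | 2881 | 1970.25 | 32205.83 | N/A | 55177.50 | 16.35 | N/A | 28.01 |
| Average | | | | | | 16.16 | N/A | 27.59 |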

Tree Copy

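| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
|---|---|---|---|---|---|---|---|---|
| TinyMLP | 53 | 89.81 | 232.26 | 845.20 | 981.48 | 2.59 | 9.41 | 10.93 |
| AlexNet | 188 | 334.58 | 959.32 | 3360.46 | 4316.05 | 2.87 | 10.04 | 12.90 |
| ResNet18 | 698 | 1128.11 | 2840.71 | 11471.07 | 12297.07 | 2.52 | 10.17 | 10.90 |
| ResNet34 | 1242 | 2160.57 | 5333.10 | 20563.06 | 21901.91 | 2.47 | 9.52 | 10.14 |
| ResNet50 | 1702 | 2746.84 | 6823.88 | 29705.99 | 28927.88 | 2.48 | 10.81 | 10.53 |
| ResNet101 | 3317 | 5762.05 | 13481.45 | 56968.78 | 60115.93 | 2.34 | 9.89 | 10.43 |
| ResNet152 | 4932 | 8151.21 | 20805.61 | 81024.06 | 84079.57 | 2.55 | 9.94 | 10.31 |
| ViT-H/14 | 3420 | 5963.61 | 15665.91 | 59813.52 | 68377.82 | 2.63 | 10.03 | 11.47 |
| Swin-B | 2881 | 5401.59 | 14255.33 | 53361.77 | 62317.07 | 2.64 | 9.88 | 11.54 |
| Average | | | | | | 2.56 | 9.97 | 11.02 |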

Tree Map

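| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
|---|---|---|---|---|---|---|---|---|
| TinyMLP | 53 | 95.13 | 243.86 | 867.34 | 1026.99 | 2.56 | 9.12 | 10.80 |
| AlexNet | 188 | 348.44 | 987.57 | 3398.32 | 4354.81 | 2.83 | 9.75 | 12.50 |
| ResNet18 | 698 | 1190.62 | 2982.66 | 11719.94 | 12559.01 | 2.51 | 9.84 | 10.55 |
| ResNet34 | 1242 | 2205.87 | 5417.60 | 20935.72 | 22308.51 | 2.46 | 9.49 | 10.11 |
| ResNet50 | 1702 | 3128.48 | 7579.55 | 30372.71 | 31638.67 | 2.42 | 9.71 | 10.11 |
| ResNet101 | 3317 | 6173.05 | 14846.57 | 59167.85 | 60245.42 | 2.41 | 9.58 | 9.76 |
| ResNet152 | 4932 | 8641.22 | 22000.74 | 84018.65 | 86182.21 | 2.55 | 9.72 | 9.97 |
| ViT-H/14 | 3420 | 6211.79 | 17077.49 | 59790.25 | 69763.86 | 2.75 | 9.63 | 11.23 |
| Swin-B | 2881 | 5673.66 | 14339.69 | 53309.17 | 59764.61 | 2.53 | 9.40 | 10.53 |
| Average | | | | | | 2.56 | 9.58 | 10.62 |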

Tree Map (nargs)

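| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
|---|---|---|---|---|---|---|---|---|
| TinyMLP | 53 | 137.06 | 389.96 | N/A | 3908.77 | 2.85 | N/A | 28.52 |
| AlexNet | 188 | 467.24 | 1496.96 | N/A | 15395.13 | 3.20 | N/A | 32.95 |
| ResNet18 | 698 | 1603.79 | 4534.01 | N/A | 50323.76 | 2.83 | N/A | 31.38 |
| ResNet34 | 1242 | 2907.64 | 8435.33 | N/A | 90389.23 | 2.90 | N/A | 31.09 |
| ResNet50 | 1702 | 4183.77 | 11382.51 | N/A | 121777.01 | 2.72 | N/A | 29.11 |
| ResNet101 | 3317 | 7721.13 | 22247.85 | N/A | 238755.17 | 2.88 | N/A | 30.92 |
| ResNet152 | 4932 | 11508.05 | 31429.39 | N/A | 360257.74 | 2.73 | N/A | 31.30 |
| ViT-H/14 | 3420 | 8294.20 | 24524.86 | N/A | 270514.87 | 2.96 | N/A | 32.61 |
| Swin-B | 2881 | 7074.62 | 20854.80 | N/A | 241120.41 | 2.95 | N/A | 34.08 |
| Average | | | | | | 2.89 | N/A | 31.33 |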

Tree Map with Path

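| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
|---|---|---|---|---|---|---|---|---|
| TinyMLP | 53 | 109.82 | 778.30 | N/A | 2186.40 | 7.09 | N/A | 19.91 |
| AlexNet | 188 | 365.16 | 2939.36 | N/A | 8355.37 | 8.05 | N/A | 22.88 |
| ResNet18 | 698 | 1308.26 | 9529.58 | N/A | 25758.24 | 7.28 | N/A | 19.69 |
| ResNet34 | 1242 | 2527.21 | 18084.89 | N/A | 45942.32 | 7.16 | N/A | 18.18 |
| ResNet50 | 1702 | 3226.03 | 22935.53 | N/A | 61275.34 | 7.11 | N/A | 18.99 |
| ResNet101 | 3317 | 6663.52 | 46878.89 | N/A | 126642.14 | 7.04 | N/A | 19.01 |
| ResNet152 | 4932 | 9378.19 | 66136.44 | N/A | 176981.01 | 7.05 | N/A | 18.87 |
| ViT-H/14 | 3420 | 7033.69 | 50418.37 | N/A | 142508.11 | 7.17 | N/A | 20.26 |
| Swin-B | 2881 | 6078.15 | 43173.22 | N/A | 116612.71 | 7.10 | N/A | 19.19 |
| Average | | | | | | 7.23 | N/A | 19.66 |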

Tree Map with Path (nargs)

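| Module | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
|---|---|---|---|---|---|---|---|---|
| TinyMLP | 53 | 146.05 | 917.00 | N/A | 3940.61 | 6.28 | N/A | 26.98 |
| AlexNet | 188 | 489.27 | 3560.76 | N/A | 15434.71 | 7.28 | N/A | 31.55 |
| ResNet18 | 698 | 1712.79 | 11171.44 | N/A | 50219.86 | 6.52 | N/A | 29.32 |
| ResNet34 | 1242 | 3112.83 | 21024.58 | N/A | 95505.71 | 6.75 | N/A | 30.68 |
| ResNet50 | 1702 | 4220.70 | 26600.82 | N/A | 121897.57 | 6.30 | N/A | 28.88 |
| ResNet101 | 3317 | 8631.34 | 54372.37 | N/A | 236555.54 | 6.30 | N/A | 27.41 |
| ResNet152 | 4932 | 12710.49 | 77643.13 | N/A | 353600.32 | 6.11 | N/A | 27.82 |
| ViT-H/14 | 3420 | 8753.09 | 58712.71 | N/A | 286365.36 | 6.71 | N/A | 32.72 |
| Swin-B | 2881 | 7359.29 | 50112.23 | N/A | 228866.66 | 6.81 | N/A | 31.10 |
| Average | | | | | | 6.56 | N/A | 29.61 |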

Changelog

See CHANGELOG.md.


License

OpTree is released under the Apache License 2.0.

OpTree is heavily based on JAX's implementation of the PyTree utility, with deep refactoring and several improvements. The original licenses can be found at JAX's Apache License 2.0 and Tensorflow's Apache License 2.0.

