
torch.func.stack_module_state

torch.func.stack_module_state(models) → params, buffers

Prepares a list of torch.nn.Modules for ensembling with vmap().

Given a list of M nn.Modules of the same class, returns two dictionaries that stack all of their parameters and buffers together, indexed by name. The stacked parameters are optimizable (i.e., they are new leaf nodes in the autograd history that are unrelated to the original parameters and can be passed directly to an optimizer).
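Because the stacked parameters are fresh leaf tensors, they can be handed straight to an optimizer. A minimal sketch (not from the original docs) illustrating this property:

```python
import torch
from torch.func import stack_module_state

# Two identical models to stack (hypothetical shapes for illustration).
models = [torch.nn.Linear(3, 3) for _ in range(2)]
params, buffers = stack_module_state(models)

# Each stacked parameter is a new leaf that requires grad,
# so an optimizer can be built over the stacked dict's values.
assert all(p.is_leaf and p.requires_grad for p in params.values())
opt = torch.optim.SGD(params.values(), lr=0.1)

loss = sum((p ** 2).sum() for p in params.values())
loss.backward()
opt.step()
```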

Here’s an example of how to ensemble over a very simple model:

```python
import torch
from torch.func import stack_module_state, vmap

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
data = torch.randn(batch_size, 3)

def wrapper(params, buffers, data):
    return torch.func.functional_call(models[0], (params, buffers), data)

params, buffers = stack_module_state(models)
output = vmap(wrapper, (0, 0, None))(params, buffers, data)

assert output.shape == (num_models, batch_size, out_features)
```

When there are submodules, this follows the state dict naming conventions:

```python
import torch.nn as nn
from torch.func import stack_module_state

class Foo(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        hidden = 4
        self.l1 = nn.Linear(in_features, hidden)
        self.l2 = nn.Linear(hidden, out_features)

    def forward(self, x):
        return self.l2(self.l1(x))

num_models = 5
in_features, out_features = 3, 3
models = [Foo(in_features, out_features) for _ in range(num_models)]
params, buffers = stack_module_state(models)
print(list(params.keys()))  # "l1.weight", "l1.bias", "l2.weight", "l2.bias"
```

Warning

All of the modules being stacked together must be the same (except for the values of their parameters/buffers). For example, they should be in the same mode (training vs. eval).
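One way to satisfy the same-mode requirement is to put every copy into the desired mode before stacking. A short sketch (assuming the simple Linear models from the earlier example):

```python
import torch
from torch.func import stack_module_state

models = [torch.nn.Linear(3, 3) for _ in range(3)]

# Put every copy into the same mode (here, eval) before stacking,
# so that modules like dropout/batchnorm behave identically under vmap.
for m in models:
    m.eval()

params, buffers = stack_module_state(models)
```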

Return type

tuple[dict[str, Any], dict[str, Any]]