# torch.func.stack_module_state

torch.func.stack_module_state(models) → params, buffers
Prepares a list of torch.nn.Modules for ensembling with vmap().

Given a list of M nn.Modules of the same class, returns two dictionaries that stack all of their parameters and buffers together, indexed by name. The stacked parameters are optimizable (i.e., they are new leaf nodes in the autograd history that are unrelated to the original parameters and can be passed directly to an optimizer).

Here's an example of how to ensemble over a very simple model:
```python
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

def wrapper(params, buffers, data):
    return torch.func.functional_call(models[0], (params, buffers), data)

params, buffers = stack_module_state(models)
output = vmap(wrapper, (0, 0, None))(params, buffers, data)

assert output.shape == (num_models, batch_size, out_features)
```
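Since the stacked parameters are fresh leaf tensors, they can be handed straight to an optimizer and trained through vmap. A minimal sketch of that workflow (the `compute_loss` wrapper and the MSE target are illustrative choices, not part of the API):

```python
import torch
from torch.func import functional_call, stack_module_state, vmap

# Same toy ensemble as the example above.
num_models, batch_size = 5, 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
data = torch.randn(batch_size, in_features)
targets = torch.randn(batch_size, out_features)  # illustrative regression targets

params, buffers = stack_module_state(models)

# The stacked params are new leaf nodes, so an optimizer can own them directly.
opt = torch.optim.SGD(params.values(), lr=1e-2)

def compute_loss(params, buffers, data):
    preds = functional_call(models[0], (params, buffers), data)
    return torch.nn.functional.mse_loss(preds, targets)

# One loss per model; summing and calling backward populates .grad on the
# stacked tensors, which the optimizer then updates in place.
loss = vmap(compute_loss, (0, 0, None))(params, buffers, data).sum()
loss.backward()
opt.step()
```

Each optimizer step updates all five models at once, since every model's weights live in one stacked tensor.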
When there are submodules, the stacked dictionaries follow state dict naming conventions:
```python
import torch.nn as nn

class Foo(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        hidden = 4
        self.l1 = nn.Linear(in_features, hidden)
        self.l2 = nn.Linear(hidden, out_features)

    def forward(self, x):
        return self.l2(self.l1(x))

num_models = 5
in_features, out_features = 3, 3
models = [Foo(in_features, out_features) for i in range(num_models)]
params, buffers = stack_module_state(models)
print(list(params.keys()))  # "l1.weight", "l1.bias", "l2.weight", "l2.bias"
```
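Because stacking adds a leading dimension of size M and the keys match state dict names, one model's state can be recovered by indexing each stacked entry. A short sketch reusing the Foo class above (the `fresh` and `state_i` names are illustrative):

```python
import torch
import torch.nn as nn
from torch.func import stack_module_state

class Foo(nn.Module):  # same toy model as the example above
    def __init__(self, in_features, out_features):
        super().__init__()
        hidden = 4
        self.l1 = nn.Linear(in_features, hidden)
        self.l2 = nn.Linear(hidden, out_features)

    def forward(self, x):
        return self.l2(self.l1(x))

models = [Foo(3, 3) for _ in range(5)]
params, buffers = stack_module_state(models)

# Slicing index i off the leading dimension yields tensors shaped exactly
# like model i's own state dict, so they load back into a fresh module.
state_i = {k: v[2] for k, v in params.items()}
fresh = Foo(3, 3)
fresh.load_state_dict(state_i)
```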
Warning
All of the modules being stacked together must be the same (except for the values of their parameters/buffers). For example, they should be in the same mode (training vs. eval).
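One way to satisfy this is to put every copy in the same mode before stacking. A small sketch, using a BatchNorm1d ensemble (an assumed example, chosen because BatchNorm has both mode-dependent behavior and buffers):

```python
import torch
from torch.func import stack_module_state

models = [torch.nn.BatchNorm1d(3) for _ in range(4)]

# Put all copies in the same mode before stacking; mixing train/eval
# modules would violate the requirement above.
for m in models:
    m.eval()

params, buffers = stack_module_state(models)
# BatchNorm's running statistics appear as stacked buffers, indexed by name.
print(sorted(buffers.keys()))  # ['num_batches_tracked', 'running_mean', 'running_var']
```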