

Model ensembling#

Created On: Mar 15, 2023 | Last Updated: Oct 02, 2025 | Last Verified: Nov 05, 2024

This tutorial illustrates how to vectorize model ensembling using torch.vmap.

What is model ensembling?#

Model ensembling combines the predictions from multiple models together. Traditionally this is done by running each model on some inputs separately and then combining the predictions. However, if you’re running models with the same architecture, then it may be possible to combine them together using torch.vmap. vmap is a function transform that maps functions across dimensions of the input tensors. One of its use cases is eliminating for-loops and speeding them up through vectorization.

Let’s demonstrate how to do this using an ensemble of simple MLPs.

Note

This tutorial requires PyTorch 2.0.0 or later.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Here's a simple MLP
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.flatten(1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset. Thus, the dummy images are 28 by 28, and we have a minibatch of size 64. Furthermore, let’s say we want to combine the predictions from 10 different models.

device = torch.accelerator.current_accelerator()
num_models = 10

data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)

models = [SimpleMLP().to(device) for _ in range(num_models)]

We have a couple of options for generating predictions. Maybe we want to give each model a different randomized minibatch of data. Alternatively, maybe we want to run the same minibatch of data through each model (e.g. if we were testing the effect of different model initializations).

Option 1: different minibatch for each model

minibatches = data[:num_models]
predictions_diff_minibatch_loop = [
    model(minibatch) for model, minibatch in zip(models, minibatches)
]

Option 2: Same minibatch

minibatch = data[0]
predictions2 = [model(minibatch) for model in models]

Using vmap to vectorize the ensemble#

Let’s use vmap to speed up the for-loop. We must first prepare the models for use with vmap.

First, let’s combine the states of the models together by stacking each parameter. For example, model[i].fc1.weight has shape [128, 784] (nn.Linear stores weights as [out_features, in_features]); we are going to stack the .fc1.weight of each of the 10 models to produce a big weight of shape [10, 128, 784].

PyTorch offers the torch.func.stack_module_state convenience function to do this.

from torch.func import stack_module_state

params, buffers = stack_module_state(models)
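
For intuition, here is a rough hand-rolled sketch of the same stacking. It is for illustration only (the real helper also handles buffers and other details), and it assumes, as stack_module_state does here, that the result is a dictionary keyed by parameter names such as 'fc1.weight'.

# Illustrative sketch only: manually stack each named parameter across the 10 models.
manual_params = {
    name: torch.stack([m.state_dict()[name] for m in models])
    for name in models[0].state_dict()
}
print(manual_params['fc1.weight'].shape)  # torch.Size([10, 128, 784])
print(params['fc1.weight'].shape)         # the same stacked shape from stack_module_state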

Next, we need to define a function to vmap over. Given parameters, buffers, and inputs, the function should run the model using those parameters, buffers, and inputs. We’ll use torch.func.functional_call to help out:

from torch.func import functional_call
import copy

# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

Option 1: get predictions using a different minibatch for each model.

By default, vmap maps a function across the first dimension of all inputs to the passed-in function. After using stack_module_state, each of the params and buffers has an additional dimension of size ‘num_models’ at the front, and minibatches has a dimension of size ‘num_models’.

print([p.size(0) for p in params.values()])  # show the leading 'num_models' dimension

assert minibatches.shape == (num_models, 64, 1, 28, 28)  # verify minibatch has leading dimension of size 'num_models'

from torch import vmap

predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)

# verify the ``vmap`` predictions match the for-loop predictions
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)
[10, 10, 10, 10, 10, 10]

Option 2: get predictions using the same minibatch of data.

vmap has an in_dims argument that specifies which dimensions to map over. By using None, we tell vmap we want the same minibatch to apply for all of the 10 models.

predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)
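
With the per-model outputs stacked along a leading dimension, the final combination step is straightforward. As a small illustrative follow-up, one simple choice is to average the per-model logits (averaging softmax probabilities or majority voting are common alternatives):

# Average the per-model logits over the leading model dimension, then pick a
# class per example.
ensemble_logits = predictions2_vmap.mean(dim=0)    # shape: [64, 10]
ensemble_classes = ensemble_logits.argmax(dim=-1)  # shape: [64]
print(ensemble_classes.shape)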

A quick note: there are limitations around what types of functions can be transformed by vmap. The best functions to transform are pure functions: functions whose outputs are determined only by the inputs and that have no side effects (e.g. mutation). vmap is unable to handle mutation of arbitrary Python data structures, but it is able to handle many in-place PyTorch operations.
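
To make that distinction concrete, here is a small sketch. Both functions below are fine to vmap; mutating external Python state (for example, appending to a module-level list) inside the mapped function would be the kind of side effect vmap does not handle for you.

import torch
from torch import vmap

# Pure function: the output depends only on the input.
def pure_fn(x):
    return (x * 2).sum(dim=-1)

# Uses an in-place PyTorch op (``mul_``) on a tensor created inside the function;
# many in-place PyTorch operations like this work under ``vmap``.
def inplace_fn(x):
    y = x.clone()
    y.mul_(2)
    return y.sum(dim=-1)

xs = torch.randn(8, 5)
assert torch.allclose(vmap(pure_fn)(xs), vmap(inplace_fn)(xs))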

Performance#

Curious about performance numbers? Here’s how the numbers look.

from torch.utils.benchmark import Timer

without_vmap = Timer(
    stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
    globals=globals(),
)
with_vmap = Timer(
    stmt="vmap(fmodel)(params, buffers, minibatches)",
    globals=globals(),
)
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')
Predictions without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f432df9d120>
[model(minibatch) for model, minibatch in zip(models, minibatches)]
  1.32 ms
  1 measurement, 100 runs , 1 thread
Predictions with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f432df4f070>
vmap(fmodel)(params, buffers, minibatches)
  530.98 us
  1 measurement, 100 runs , 1 thread

There’s a large speedup using vmap!

In general, vectorization with vmap should be faster than running a function in a for-loop and competitive with manual batching. There are some exceptions though, like if we haven’t implemented the vmap rule for a particular operation or if the underlying kernels weren’t optimized for older hardware (GPUs). If you see any of these cases, please let us know by opening an issue on GitHub.
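
For comparison, “manual batching” for this ensemble would mean writing the batched math by hand against the stacked parameters, e.g. with torch.baddbmm. The sketch below is illustrative only and assumes the params dictionary returned by stack_module_state is keyed by parameter names such as 'fc1.weight':

def manually_batched_forward(params, x):
    # x: [num_models, batch_size, 784]
    # w: [num_models, out_features, in_features], b: [num_models, out_features]
    def batched_linear(x, w, b):
        return torch.baddbmm(b.unsqueeze(1), x, w.transpose(1, 2))

    x = batched_linear(x, params['fc1.weight'], params['fc1.bias'])
    x = F.relu(x)
    x = batched_linear(x, params['fc2.weight'], params['fc2.bias'])
    x = F.relu(x)
    return batched_linear(x, params['fc3.weight'], params['fc3.bias'])

predictions_manual = manually_batched_forward(params, minibatches.flatten(2))
assert torch.allclose(predictions_manual, predictions1_vmap, atol=1e-3, rtol=1e-5)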

Total running time of the script: (0 minutes 0.794 seconds)