torch.vmap
This tutorial introduces torch.vmap, an autovectorizer for PyTorch operations. torch.vmap is a prototype feature and cannot handle a number of use cases; however, we would like to gather use cases for it to inform the design. If you are considering using torch.vmap or think it would be really cool for something, please contact us at pytorch/pytorch#42368.
So, what is vmap?
vmap is a higher-order function. It accepts a function func and returns a new function that maps func over some dimension of the inputs. It is highly inspired by JAX’s vmap.
Semantically, vmap pushes the “map” into the PyTorch operations called by func, effectively vectorizing those operations.
import torch

# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.
from torch import vmap
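As a quick illustration of those semantics, vmap(func)(inputs) behaves like applying func to each slice along the mapped dimension and stacking the results, even though no Python loop actually runs. The helper function below is made up for this sketch and assumes the operation it uses is supported by the prototype:

# Illustrative only: vmap behaves like a per-row application of the function.
def row_sum(vec):
    return vec.sum()

xs = torch.randn(4, 3)
out = torch.vmap(row_sum)(xs)                              # shape [4]
expected = torch.stack([row_sum(v) for v in xs.unbind()])  # explicit loop
assert torch.allclose(out, expected)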
The first use case for vmap is making it easier to handle batch dimensions in your code. One can write a function func that runs on examples and then lift it to a function that can take batches of examples with vmap(func). func, however, is subject to many restrictions:
it must be functional (one cannot mutate a Python data structure inside of it), with the exception of in-place PyTorch operations.
batches of examples must be provided as Tensors. This means that vmap doesn’t handle variable-length sequences out of the box.
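By default, vmap maps over dimension 0 of every input. When only some inputs carry a batch dimension, the in_dims argument (also mentioned in the model example further below) lets you say which inputs to map. The shapes in this sketch are made up for illustration:

# Map over dim 0 of the first input only; the second input is shared across the batch.
xs = torch.randn(4, 5)  # a batch of 4 vectors
w = torch.randn(5)      # one shared vector
dots = torch.vmap(torch.dot, in_dims=(0, None))(xs, w)          # shape [4]
expected = torch.stack([torch.dot(x, w) for x in xs.unbind()])  # explicit loop
assert torch.allclose(dots, expected)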
One example of using vmap is to compute batched dot products. PyTorch doesn’t provide a batched torch.dot API; instead of unsuccessfully rummaging through docs, use vmap to construct a new function:
torch.dot                            # [D], [D] -> []
batched_dot = torch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
x, y = torch.randn(2, 5), torch.randn(2, 5)
batched_dot(x, y)
tensor([-2.6925, 0.5633])
vmap can be helpful in hiding batch dimensions, leading to a simpler model authoring experience.
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

# Note that model doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigans), we can simply call `vmap`. `vmap` batches over ALL
# inputs, unless otherwise specified (with the in_dims argument,
# please see the documentation for more details).
def model(feature_vec):
    # Very simple linear model with activation
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = torch.vmap(model)(examples)
expected = torch.stack([model(example) for example in examples.unbind()])
assert torch.allclose(result, expected)
vmap can also help vectorize computations that were previously difficult or impossible to batch. This brings us to our second use case: batched gradient computation.
The PyTorch autograd engine computes vjps (vector-Jacobian products). Using vmap, we can compute (batched vector)-Jacobian products.
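To make the term concrete, here is a single (unbatched) vector-Jacobian product; the elementwise exp function is chosen purely for illustration:

# For an elementwise function like exp, the Jacobian is diag(exp(a)),
# so the vector-Jacobian product v^T J is simply v * exp(a).
a = torch.randn(3, requires_grad=True)
out = a.exp()
v = torch.randn(3)
vjp, = torch.autograd.grad(out, a, v)
assert torch.allclose(vjp, v * a.exp())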
One example of this is computing a full Jacobian matrix (this can also be applied to computing a full Hessian matrix). Computing a full Jacobian matrix for some function f: R^N -> R^N usually requires N calls to autograd.grad, one per Jacobian row.
# Setup
N = 5

def f(x):
    return x ** 2

x = torch.randn(N, requires_grad=True)
y = f(x)
basis_vectors = torch.eye(N)

# Sequential approach
jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
                 for v in basis_vectors.unbind()]
jacobian = torch.stack(jacobian_rows)

# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.
def get_vjp(v):
    return torch.autograd.grad(y, x, v)[0]

jacobian_vmap = vmap(get_vjp)(basis_vectors)
assert torch.allclose(jacobian_vmap, jacobian)
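As noted above, the same pattern extends to a full Hessian: build the gradient with create_graph=True, then vmap a Hessian-vector product over the basis vectors. The sketch below uses a made-up scalar function g and assumes the prototype handles this second-order case; treat it as an illustration of the pattern rather than guaranteed-supported code.

# Hessian of a scalar function g: R^N -> R via batched Hessian-vector products.
def g(x):
    return (x ** 3).sum()

x2 = torch.randn(N, requires_grad=True)
grad_g, = torch.autograd.grad(g(x2), x2, create_graph=True)

def get_hvp(v):
    # Differentiating <grad_g, v> w.r.t. x2 gives one Hessian-vector product.
    return torch.autograd.grad(grad_g, x2, v)[0]

hessian = vmap(get_hvp)(torch.eye(N))
# For g(x) = sum(x^3), the Hessian is diag(6 * x).
assert torch.allclose(hessian, torch.diag(6 * x2).detach())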
The third main use case for vmap is computing per-sample-gradients. This is something that the vmap prototype cannot handle performantly right now. We’re not sure what the API for computing per-sample-gradients should be, but if you have ideas, please comment in pytorch/pytorch#7786.
# Shared model weight, defined here so the sketch below is self-contained.
weight = torch.randn(5, requires_grad=True)

def model(sample, weight):
    # do something...
    return torch.dot(sample, weight)

def grad_sample(sample):
    # Gradient of the (scalar) model output with respect to `weight` for one sample.
    return torch.autograd.functional.vjp(lambda weight: model(sample, weight), weight)[1]

# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.

# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)
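For reference, per-sample gradients can already be computed today with a plain Python loop over autograd.grad; this is the computation a vmap-based API would batch. The batch shape below is illustrative:

# Sequential per-sample gradients: one autograd.grad call per sample.
batch_of_samples = torch.randn(64, 5)
per_sample_grads = torch.stack([
    torch.autograd.grad(model(sample, weight), weight)[0]
    for sample in batch_of_samples.unbind()
])
assert per_sample_grads.shape == (64, 5)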