torch.vmap
- torch.vmap(func, in_dims=0, out_dims=0, randomness='error', *, chunk_size=None)
vmap is the vectorizing map; vmap(func) returns a new function that maps func over some dimension of the inputs. Semantically, vmap pushes the map into PyTorch operations called by func, effectively vectorizing those operations.
vmap is useful for handling batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with vmap(func). vmap can also be used to compute batched gradients when composed with autograd.
Note
torch.vmap() is aliased to torch.func.vmap() for convenience. Use whichever one you’d like.
- Parameters
func (function) – A Python function that takes one or more arguments. Must return one or more Tensors.
in_dims (int or nested structure) – Specifies which dimension of the inputs should be mapped over. in_dims should have a structure like the inputs. If the in_dim for a particular input is None, then that indicates there is no map dimension. Default: 0.
out_dims (int or Tuple[int]) – Specifies where the mapped dimension should appear in the outputs. If out_dims is a Tuple, then it should have one element per output. Default: 0.
randomness (str) – Specifies whether the randomness in this vmap should be the same or different across batches. If ‘different’, the randomness for each batch will be different. If ‘same’, the randomness will be the same across batches. If ‘error’, any calls to random functions will error. Default: ‘error’. WARNING: this flag only applies to random PyTorch operations and does not apply to Python’s random module or numpy randomness.
chunk_size (None or int) – If None (default), apply a single vmap over inputs. If not None, then compute the vmap chunk_size samples at a time. Note that chunk_size=1 is equivalent to computing the vmap with a for-loop. If you run into memory issues computing the vmap, please try a non-None chunk_size.
- Returns
Returns a new “batched” function. It takes the same inputs as func, except each input has an extra dimension at the index specified by in_dims. It returns the same outputs as func, except each output has an extra dimension at the index specified by out_dims.
- Return type
Callable
One example of using vmap() is to compute batched dot products. PyTorch doesn’t provide a batched torch.dot API; instead of unsuccessfully rummaging through docs, use vmap() to construct a new function.
>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.func.vmap(torch.dot)  # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)
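As a quick sanity check of the batched dot product described above, the following sketch compares the vmap result against an explicit Python loop over rows. The names out and expected are illustrative and not part of the original documentation.

```python
import torch

torch.manual_seed(0)
x, y = torch.randn(2, 5), torch.randn(2, 5)

# Lift torch.dot ([D], [D] -> []) to operate on a leading batch dimension.
batched_dot = torch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
out = batched_dot(x, y)

# Reference: one torch.dot call per pair of rows, stacked into a [N] tensor.
expected = torch.stack([torch.dot(xi, yi) for xi, yi in zip(x, y)])

print(out.shape)  # torch.Size([2])
print(torch.allclose(out, expected))
```

The loop version makes one torch.dot call per example, while the vmap version expresses the same computation as a single batched call.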
vmap() can be helpful in hiding batch dimensions, leading to a simpler model authoring experience.
>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>>     # Very simple linear model with activation
>>>     return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = torch.vmap(model)(examples)
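To make the "hidden batch dimension" concrete, this sketch compares the per-example model lifted by vmap with a hand-batched version of the same computation, and checks that gradients still reach weights. The hand-batched expression (examples @ weights).relu() is an assumption about how one would write the model manually; it does not appear in the original text.

```python
import torch

torch.manual_seed(0)
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

def model(feature_vec):
    # Written for a single example of shape [feature_size].
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)

# vmap adds the batch dimension without changing the model code.
result = torch.vmap(model)(examples)  # [batch_size]

# Hand-batched equivalent: the batch dimension is handled explicitly.
manual = (examples @ weights).relu()

# The vmapped output is an ordinary tensor, so autograd works as usual.
result.sum().backward()
print(result.shape, weights.grad.shape)
```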
vmap() can also help vectorize computations that were previously difficult or impossible to batch. One example is higher-order gradient computation. The PyTorch autograd engine computes vjps (vector-Jacobian products). Computing a full Jacobian matrix for some function f: R^N -> R^N usually requires N calls to autograd.grad, one per Jacobian row. Using vmap(), we can vectorize the whole computation, computing the Jacobian in a single call to autograd.grad.
>>> # Setup
>>> N = 5
>>> f = lambda x: x ** 2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
>>>
>>> # Sequential approach
>>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
>>>                  for v in I_N.unbind()]
>>> jacobian = torch.stack(jacobian_rows)
>>>
>>> # vectorized gradient computation
>>> def get_vjp(v):
>>>     # autograd.grad returns a tuple with one entry per input; take the only one
>>>     return torch.autograd.grad(y, x, v)[0]
>>> jacobian = torch.vmap(get_vjp)(I_N)
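For f(x) = x ** 2 the Jacobian is known in closed form: it is the diagonal matrix with entries 2 * x. The sketch below uses that fact to check the vectorized Jacobian, and also cross-checks against torch.func.jacrev, which is not mentioned in the text above and is included here only as an independent reference.

```python
import torch

torch.manual_seed(0)
N = 5
f = lambda x: x ** 2
x = torch.randn(N, requires_grad=True)
y = f(x)
I_N = torch.eye(N)

def get_vjp(v):
    # One vector-Jacobian product; row i of I_N selects row i of the Jacobian.
    return torch.autograd.grad(y, x, v)[0]

# All N vjps are computed in a single vectorized autograd.grad call.
jacobian = torch.vmap(get_vjp)(I_N)  # [N, N]

# Analytic reference: d(x_i ** 2) / d(x_j) = 2 * x_i if i == j, else 0.
expected = torch.diag(2 * x.detach())

# Independent reference computed by torch.func.jacrev.
reference = torch.func.jacrev(f)(x.detach())

print(torch.allclose(jacobian, expected))
```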
vmap() can also be nested, producing an output with multiple batched dimensions:
>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.vmap(torch.dot))  # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
>>> batched_dot(x, y)  # tensor of size [2, 3]
If the inputs are not batched along the first dimension, in_dims specifies the dimension that each input is batched along:
>>> torch.dot                            # [N], [N] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=1)  # [N, D], [N, D] -> [D]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)  # output has shape [5]; it would be [2] if batched along dimension 0
If there are multiple inputs, each of which is batched along a different dimension, in_dims must be a tuple with the batch dimension for each input:
>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None))  # [N, D], [D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> batched_dot(x, y)  # second arg doesn't have a batch dim because in_dims[1] was None
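The tuple form of in_dims combines naturally with nesting. As an illustrative sketch (the names inner and pairwise_dot are hypothetical), two nested vmap calls with complementary None entries compute every pairwise dot product between two sets of vectors, which should match an ordinary matrix product.

```python
import torch

torch.manual_seed(0)
x, y = torch.randn(3, 5), torch.randn(4, 5)

# Inner map: hold the first argument fixed, map over rows of the second.
inner = torch.vmap(torch.dot, in_dims=(None, 0))  # [D], [M, D] -> [M]

# Outer map: map over rows of the first argument, hold the second fixed.
pairwise_dot = torch.vmap(inner, in_dims=(0, None))  # [N, D], [M, D] -> [N, M]

out = pairwise_dot(x, y)

# Reference: entry (i, j) of x @ y.T is the dot product of x[i] and y[j].
expected = x @ y.T

print(out.shape)  # torch.Size([3, 4])
```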
If the input is a Python struct, in_dims must be a tuple containing a struct matching the shape of the input:
>>> f = lambda dict: torch.dot(dict["x"], dict["y"])
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> input = {"x": x, "y": y}
>>> batched_dot = torch.vmap(f, in_dims=({"x": 0, "y": None},))
>>> batched_dot(input)
By default, the output is batched along the first dimension. However, it can be batched along any dimension by using out_dims:
>>> f = lambda x: x ** 2
>>> x = torch.randn(2, 5)
>>> batched_pow = torch.vmap(f, out_dims=1)
>>> batched_pow(x)  # [5, 2]
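The parameter description notes that out_dims may also be a tuple with one element per output. A minimal sketch of both forms follows; the helper names square and square_and_sum are illustrative only.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 5)

def square(row):
    return row ** 2  # per-example output of shape [5]

default_out = torch.vmap(square)(x)            # batch dim first: [2, 5]
moved_out = torch.vmap(square, out_dims=1)(x)  # batch dim second: [5, 2]

def square_and_sum(row):
    # Two outputs: a vector of shape [5] and a scalar.
    return row ** 2, row.sum()

# One out_dims entry per output: batch dim at index 1 for the first output,
# and at index 0 for the second.
squares, sums = torch.vmap(square_and_sum, out_dims=(1, 0))(x)

print(moved_out.shape, squares.shape, sums.shape)
```

Moving the output batch dimension only changes the layout: moved_out holds the same values as default_out transposed.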
For any function that uses kwargs, the returned function will not batch the kwargs but will accept kwargs:
>>> x = torch.randn([2, 5])
>>> def fn(x, scale=4.):
>>>     return x * scale
>>>
>>> batched_pow = torch.vmap(fn)
>>> assert torch.allclose(batched_pow(x), x * 4)
>>> batched_pow(x, scale=x)  # scale is not batched, output has shape [2, 2, 5]
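The randomness and chunk_size parameters are not exercised by the examples above. The following sketch illustrates the behavior their descriptions state; the helpers add_noise and row_stat are hypothetical, and the assumption that the default 'error' mode raises a RuntimeError reflects observed behavior rather than anything stated in the parameter description.

```python
import torch

torch.manual_seed(0)
x = torch.zeros(4, 3)

def add_noise(row):
    # torch.rand is a random PyTorch operation, so the randomness flag applies.
    return row + torch.rand(row.shape)

# 'same': every example in the batch receives identical noise.
same = torch.vmap(add_noise, randomness="same")(x)

# 'different': each example receives its own noise.
different = torch.vmap(add_noise, randomness="different")(x)

# Default 'error': calling a random operation inside vmap raises.
try:
    torch.vmap(add_noise)(x)
    raised = False
except RuntimeError:
    raised = True

def row_stat(row):
    return row.sin().sum()

data = torch.randn(7, 3)
full = torch.vmap(row_stat)(data)                   # one vmap over all 7 rows
chunked = torch.vmap(row_stat, chunk_size=2)(data)  # 2 rows at a time

print(raised, torch.allclose(full, chunked))
```

Chunking trades some speed for lower peak memory, and the results should agree with the unchunked call up to floating-point tolerance.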
Note
vmap does not provide general autobatching or handle variable-length sequences out of the box.
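One common workaround for the variable-length limitation, shown here only as a sketch and not as a recommendation from the documentation, is to pad the sequences to a common length and pass the true lengths as a second batched input. The function masked_mean and the choice of torch.nn.utils.rnn.pad_sequence for padding are assumptions made for this illustration.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two sequences of different lengths.
seqs = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0])]
lengths = torch.tensor([len(s) for s in seqs])

# Pad with zeros to a rectangular [batch, max_len] tensor.
padded = pad_sequence(seqs, batch_first=True)

def masked_mean(seq, length):
    # seq has shape [max_len]; positions at or beyond `length` are padding.
    mask = torch.arange(seq.shape[0]) < length
    return (seq * mask).sum() / length

# Both inputs are batched along dimension 0 (the default in_dims).
means = torch.vmap(masked_mean)(padded, lengths)

print(means)  # mean of each sequence, ignoring padding
```

The mask keeps padded positions from contributing to the result, so each output equals the mean of the original, unpadded sequence.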