
torch.func.vmap

torch.func.vmap(func, in_dims=0, out_dims=0, randomness='error', *, chunk_size=None)

vmap is the vectorizing map; vmap(func) returns a new function that maps func over some dimension of the inputs. Semantically, vmap pushes the map into PyTorch operations called by func, effectively vectorizing those operations.

vmap is useful for handling batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with vmap(func). vmap can also be used to compute batched gradients when composed with autograd.
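To make "lifting a per-example function to batches" concrete, the sketch below compares vmap(func) against an explicit Python loop that applies func to each example and stacks the results along dim 0. The function `func` and the tensor shapes are illustrative assumptions, not part of the API.

```python
import torch
from torch.func import vmap

torch.manual_seed(0)

def func(example):
    # Written for a single example of shape [3]; knows nothing about batches
    return example.sin() * 2 + example.sum()

batch = torch.randn(4, 3)

# Vectorized: vmap maps func over dim 0 of `batch` in one call
vectorized = vmap(func)(batch)

# Reference: an explicit loop over examples, stacked along dim 0
looped = torch.stack([func(example) for example in batch])

print(vectorized.shape)  # torch.Size([4, 3])
print(torch.allclose(vectorized, looped))  # True
```

The loop version is the semantic reference; vmap produces the same values while running the underlying operations in batched form.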

Note

torch.vmap() is aliased to torch.func.vmap() for convenience. Use whichever one you’d like.

Parameters
  • func (function) – A Python function that takes one or more arguments. Must return one or more Tensors.

  • in_dims (int or nested structure) – Specifies which dimension of the inputs should be mapped over. in_dims should have a structure like the inputs. If the in_dims entry for a particular input is None, that input has no mapped dimension and is passed unchanged to every call. Default: 0.

  • out_dims (int or Tuple[int]) – Specifies where the mapped dimension should appear in the outputs. If out_dims is a Tuple, then it should have one element per output. Default: 0.

  • randomness (str) – Specifies whether the randomness in this vmap should be the same or different across batches. If ‘different’, the randomness for each batch will be different. If ‘same’, the randomness will be the same across batches. If ‘error’, any calls to random functions will raise an error. Default: ‘error’. WARNING: this flag only applies to random PyTorch operations and does not apply to Python’s random module or NumPy randomness.

  • chunk_size (None or int) – If None (default), apply a single vmap over the inputs. If not None, compute the vmap chunk_size samples at a time. Note that chunk_size=1 is equivalent to computing the vmap with a for-loop. If you run into memory issues computing the vmap, try a non-None chunk_size.
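The randomness and chunk_size parameters are easiest to understand by observation. The sketch below assumes `torch.rand_like` as the random operation and uses illustrative function names (`add_noise`, `scale`); it checks that the default mode raises on a random op, that ‘same’ shares one draw across the batch, that ‘different’ draws independently, and that chunk_size leaves a deterministic result unchanged.

```python
import torch
from torch.func import vmap

torch.manual_seed(0)

def add_noise(x):
    # torch.rand_like is a random PyTorch op, so vmap's `randomness` flag governs it
    return x + torch.rand_like(x)

zeros = torch.zeros(4, 3)

# Default randomness='error': calling a random op inside vmap raises
try:
    vmap(add_noise)(zeros)
    raised = False
except RuntimeError:
    raised = True

# 'same': one random draw, shared by every batch element
same = vmap(add_noise, randomness="same")(zeros)
# 'different': an independent random draw per batch element
different = vmap(add_noise, randomness="different")(zeros)

print(raised)                         # True
print(torch.equal(same[0], same[1]))  # True

# chunk_size only changes how many samples are processed at a time, not the result
def scale(x):
    return x * 3.0

data = torch.arange(12.0).reshape(4, 3)
full = vmap(scale)(data)
chunked = vmap(scale, chunk_size=2)(data)
print(torch.equal(full, chunked))  # True
```

Rows of `different` will almost surely differ from one another, while every row of `same` is identical because a single draw is broadcast across the batch.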

Returns

Returns a new “batched” function. It takes the same inputs as func, except each input has an extra dimension at the index specified by in_dims. It returns the same outputs as func, except each output has an extra dimension at the index specified by out_dims.
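When func returns several outputs, out_dims can be a tuple with one entry per output. The sketch below uses a hypothetical `stats` function that returns a per-example mean and a centered vector, placing the batch dimension first in the first output and last in the second.

```python
import torch
from torch.func import vmap

torch.manual_seed(0)

def stats(x):
    # Two outputs per example: a scalar mean and the centered vector
    mean = x.mean()
    return mean, x - mean

batch = torch.randn(4, 3)

# One out_dims entry per output: batch dim at index 0 for `means`,
# at index 1 for `centered`
means, centered = vmap(stats, out_dims=(0, 1))(batch)

print(means.shape)     # torch.Size([4])
print(centered.shape)  # torch.Size([3, 4])
```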

Return type

Callable

One example of using vmap() is to compute batched dot products. PyTorch doesn’t provide a batched torch.dot API; instead of unsuccessfully rummaging through docs, use vmap() to construct a new function.

>>> import torch
>>> torch.dot                                 # [D], [D] -> []
>>> batched_dot = torch.func.vmap(torch.dot)  # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)
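A quick way to convince yourself the batched dot is correct is to compare it with the hand-written equivalent: multiply elementwise and sum over the feature dimension. This is a sketch for checking, with the same shapes as above.

```python
import torch

torch.manual_seed(0)
x, y = torch.randn(2, 5), torch.randn(2, 5)

batched_dot = torch.func.vmap(torch.dot)  # [N, D], [N, D] -> [N]
result = batched_dot(x, y)

# Reference: elementwise product summed over the feature dimension
expected = (x * y).sum(dim=-1)

print(result.shape)                      # torch.Size([2])
print(torch.allclose(result, expected))  # True
```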

vmap() can be helpful in hiding batch dimensions, leading to a simpler model authoring experience.

>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
...     # Very simple linear model with activation
...     return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = torch.vmap(model)(examples)
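The sketch below checks the vmapped model against the version that handles the batch dimension by hand (a matrix-vector product), and shows that regular autograd still flows back to `weights` through the vmapped call. The seed and the explicit comparison are illustrative additions.

```python
import torch

torch.manual_seed(0)
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

def model(feature_vec):
    # Written for a single example of shape [feature_size]
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = torch.vmap(model)(examples)  # shape [batch_size]

# Same computation with the batch dimension handled explicitly
explicit = (examples @ weights).relu()
print(torch.allclose(result, explicit))  # True

# The vmapped output still participates in ordinary autograd
result.sum().backward()
print(weights.grad.shape)  # torch.Size([5])
```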

vmap() can also help vectorize computations that were previously difficult or impossible to batch. One example is higher-order gradient computation. The PyTorch autograd engine computes vjps (vector-Jacobian products). Computing a full Jacobian matrix for some function f: R^N -> R^N usually requires N calls to autograd.grad, one per Jacobian row. Using vmap(), we can vectorize the whole computation, computing the Jacobian in a single call to autograd.grad.

>>> # Setup
>>> N = 5
>>> f = lambda x: x ** 2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
>>>
>>> # Sequential approach
>>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
...                  for v in I_N.unbind()]
>>> jacobian = torch.stack(jacobian_rows)
>>>
>>> # Vectorized gradient computation
>>> def get_vjp(v):
...     return torch.autograd.grad(y, x, v)[0]
>>> jacobian = torch.vmap(get_vjp)(I_N)
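The same "one vjp per identity row, batched by vmap" idea can be expressed entirely with torch.func transforms. The sketch below swaps in torch.func.vjp in place of torch.autograd.grad (a substitution, not the example above), then checks the result against the known answer diag(2x) for f(x) = x**2 and against torch.func.jacrev.

```python
import torch
from torch.func import jacrev, vjp, vmap

torch.manual_seed(0)
N = 5
f = lambda x: x ** 2
x = torch.randn(N)

# vjp returns f(x) and a function mapping a cotangent v to v^T J
_, vjp_fn = vjp(f, x)

# Each identity row as a cotangent yields one Jacobian row;
# vmap batches those N vjp calls into one
I_N = torch.eye(N)
(jacobian,) = vmap(vjp_fn)(I_N)

# For f(x) = x**2 the Jacobian is diag(2x)
print(torch.allclose(jacobian, torch.diag(2 * x)))  # True
print(torch.allclose(jacobian, jacrev(f)(x)))       # True
```

jacrev performs essentially this vmap-over-vjp pattern for you, so in practice it is the more convenient entry point.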

vmap() can also be nested, producing an output with multiple batched dimensions:

>>> torch.dot  # [D], [D] -> []
>>> batched_dot = torch.vmap(
...     torch.vmap(torch.dot)
... )  # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
>>> batched_dot(x, y)  # tensor of size [2, 3]
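Each level of nesting peels off one leading dimension. As a sanity check, the nested batched dot should match an einsum that contracts only the last dimension and keeps both batch dimensions; the einsum subscripts here are an illustrative choice.

```python
import torch

torch.manual_seed(0)
x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)

# Outer vmap maps over N1 (size 2); inner vmap maps over N0 (size 3)
batched_dot = torch.vmap(torch.vmap(torch.dot))
result = batched_dot(x, y)

# Equivalent einsum: contract the last dim, keep both batch dims
expected = torch.einsum("abd,abd->ab", x, y)

print(result.shape)                      # torch.Size([2, 3])
print(torch.allclose(result, expected))  # True
```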

If the inputs are not batched along the first dimension, in_dims specifies the dimension along which each input is batched:

>>> torch.dot  # [N], [N] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=1)  # [N, D], [N, D] -> [D]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)  # output is [5] instead of [2] if batched along the 0th dimension
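With in_dims=1, vmap slices columns instead of rows, so each torch.dot call receives two length-2 vectors. The sketch below confirms this by comparing with the product summed over dim 0.

```python
import torch

torch.manual_seed(0)
x, y = torch.randn(2, 5), torch.randn(2, 5)

# in_dims=1 maps over the 5 columns; each dot sees two length-2 vectors
result = torch.vmap(torch.dot, in_dims=1)(x, y)

# Equivalent: dot each column pair, i.e. sum the product over dim 0
expected = (x * y).sum(dim=0)

print(result.shape)                      # torch.Size([5])
print(torch.allclose(result, expected))  # True
```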

If there are multiple inputs, each batched along a different dimension, in_dims must be a tuple with the batch dimension for each input:

>>> torch.dot  # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None))  # [N, D], [D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> batched_dot(x, y)  # second arg doesn't have a batch dim because in_dims[1] was None
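The sketch below shows two per-input settings side by side: (0, None), where the unbatched second argument is shared by every call (equivalent to a matrix-vector product), and (0, 1), where the two arguments carry their batch dimension in different positions. The tensor `z` is an illustrative assumption.

```python
import torch

torch.manual_seed(0)
x, y = torch.randn(2, 5), torch.randn(5)

# x is mapped over dim 0; y (in_dims entry None) goes whole to every call
shared = torch.vmap(torch.dot, in_dims=(0, None))(x, y)
print(shared.shape)                   # torch.Size([2])
print(torch.allclose(shared, x @ y))  # True

# Different batch dims per input: x batched on dim 0, z batched on dim 1
z = torch.randn(5, 2)
mixed = torch.vmap(torch.dot, in_dims=(0, 1))(x, z)
print(mixed.shape)                                   # torch.Size([2])
print(torch.allclose(mixed, (x * z.T).sum(dim=-1)))  # True
```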

If the input is a Python structure (such as a dict), in_dims must be a tuple containing a structure matching the shape of the input:

>>> f = lambda d: torch.dot(d["x"], d["y"])
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> inputs = {"x": x, "y": y}
>>> batched_dot = torch.vmap(f, in_dims=({"x": 0, "y": None},))
>>> batched_dot(inputs)
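The in_dims structure can be as fine-grained as needed. The sketch below shows a per-leaf specification for a dict argument and, as a contrast, that a single int for that argument applies to all of its leaves when they share the same batch dimension. The function `f` and tensor names are illustrative.

```python
import torch

torch.manual_seed(0)

def f(d):
    return torch.dot(d["x"], d["y"])

x, y = torch.randn(2, 5), torch.randn(5)

# Per-leaf in_dims: batch d["x"] along dim 0, pass d["y"] unbatched
mixed = torch.vmap(f, in_dims=({"x": 0, "y": None},))({"x": x, "y": y})

# When every leaf is batched along the same dim, one int for the
# argument is applied to all of its leaves
x2, y2 = torch.randn(2, 5), torch.randn(2, 5)
uniform = torch.vmap(f, in_dims=(0,))({"x": x2, "y": y2})

print(mixed.shape, uniform.shape)                     # torch.Size([2]) torch.Size([2])
print(torch.allclose(mixed, x @ y))                   # True
print(torch.allclose(uniform, (x2 * y2).sum(dim=-1)))  # True
```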

By default, the output is batched along the first dimension. However, it can be batched along any dimension by using out_dims:

>>> f = lambda x: x ** 2
>>> x = torch.randn(2, 5)
>>> batched_pow = torch.vmap(f, out_dims=1)
>>> batched_pow(x)  # [5, 2]
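Moving the output batch dimension does not change any values; it only changes where that dimension sits. For a 2-D result, out_dims=1 is therefore just the transpose of the default layout, as this sketch checks.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 5)
f = lambda t: t ** 2

default = torch.vmap(f)(x)            # batch dim first: [2, 5]
moved = torch.vmap(f, out_dims=1)(x)  # batch dim second: [5, 2]

print(default.shape, moved.shape)     # torch.Size([2, 5]) torch.Size([5, 2])
print(torch.equal(moved, default.T))  # True
```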

For any function that uses kwargs, the returned function will accept kwargs but will not batch them:

>>> x = torch.randn([2, 5])
>>> def fn(x, scale=4.):
...     return x * scale
>>>
>>> batched_fn = torch.vmap(fn)
>>> assert torch.allclose(batched_fn(x), x * 4)
>>> batched_fn(x, scale=x)  # scale is not batched, output has shape [2, 2, 5]
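The [2, 2, 5] shape follows from kwargs being forwarded whole: each per-example call sees a [5] slice of x but the full [2, 5] scale, and the product broadcasts. The sketch below makes this explicit by checking one slice of the output.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 5)

def fn(x, scale=4.0):
    return x * scale

batched_fn = torch.vmap(fn)

# A scalar kwarg is simply forwarded to every call
doubled = batched_fn(x, scale=2.0)
print(torch.allclose(doubled, x * 2.0))  # True

# A tensor kwarg is NOT sliced: each call sees the full [2, 5] `scale`,
# so each per-example [5] result broadcasts to [2, 5]
out = batched_fn(x, scale=x)
print(out.shape)  # torch.Size([2, 2, 5])

# out[i] equals x[i] multiplied elementwise by every row of x
print(torch.allclose(out[0], x[0] * x))  # True
```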

Note

vmap does not provide general autobatching or handle variable-length sequences out of the box.