UX Limitations#

torch.func, like JAX, has restrictions around what can be transformed. In general, JAX’s limitations are that transforms only work with pure functions: that is, functions where the output is completely determined by the input and that do not involve side effects (like mutation).

We have a similar guarantee: our transforms work well with pure functions. However, we do support certain in-place operations. On one hand, writing code compatible with function transforms may involve changing how you write PyTorch code; on the other hand, you may find that our transforms let you express things that were previously difficult to express in PyTorch.

General limitations#

All torch.func transforms share a limitation in that a function should not assign to global variables. Instead, all outputs of a function must be returned from it. This restriction comes from how torch.func is implemented: each transform wraps Tensor inputs in special torch.func Tensor subclasses that facilitate the transform.

So, instead of the following:

import torch
from torch.func import grad

# Don't do this
intermediate = None

def f(x):
    global intermediate
    intermediate = x.sin()
    z = intermediate.sin()
    return z

x = torch.randn([])
grad_x = grad(f)(x)

Please rewrite f to return intermediate:

def f(x):
    intermediate = x.sin()
    z = intermediate.sin()
    return z, intermediate

grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd APIs#

If you are trying to use a torch.autograd API like torch.autograd.grad or torch.autograd.backward inside of a function being transformed by vmap() or one of torch.func’s AD transforms (vjp(), jvp(), jacrev(), jacfwd()), the transform may not be able to transform over it. If it is unable to do so, you’ll receive an error message.

This is a fundamental design limitation in how PyTorch’s AD support is implemented and the reason why we designed the torch.func library. Please instead use the torch.func equivalents of the torch.autograd APIs (a short example follows the list below):

  • torch.autograd.grad, Tensor.backward -> torch.func.vjp or torch.func.grad

  • torch.autograd.functional.jvp -> torch.func.jvp

  • torch.autograd.functional.jacobian -> torch.func.jacrev or torch.func.jacfwd

  • torch.autograd.functional.hessian -> torch.func.hessian
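As a concrete illustration, here is a minimal sketch (using a hypothetical quadratic loss function) of replacing a torch.autograd.grad call with its torch.func.grad equivalent:

import torch
from torch.func import grad

def loss(x):
    # hypothetical scalar loss, used only for illustration
    return (x ** 2).sum()

x = torch.randn(3, requires_grad=True)

# torch.autograd style: build a graph, then differentiate the output
(autograd_grads,) = torch.autograd.grad(loss(x), x)

# torch.func style: transform the function itself
func_grads = grad(loss)(x)

assert torch.allclose(autograd_grads, func_grads)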

vmap limitations#

Note

vmap() is our most restrictive transform. The grad-related transforms (grad(), vjp(), jvp()) do not have these limitations. jacfwd() (and hessian(), which is implemented with jacfwd()) is a composition of vmap() and jvp(), so it also has these limitations.

vmap(func) is a transform that returns a function that maps func over some new dimension of each input Tensor. The mental model for vmap is that it is like running a for-loop: for pure functions (i.e. in the absence of side effects), vmap(f)(x) is equivalent to:

torch.stack([f(x_i) for x_i in x.unbind(0)])
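For intuition, here is a small sketch (with a hypothetical pure function f) checking that vmap matches the for-loop semantics:

import torch
from torch.func import vmap

def f(x):
    # hypothetical pure function applied to each example
    return x.sin() + 1

x = torch.randn(4, 3)
looped = torch.stack([f(x_i) for x_i in x.unbind(0)])
batched = vmap(f)(x)
assert torch.allclose(looped, batched)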

Mutation: Arbitrary mutation of Python data structures#

In the presence of side effects, vmap() no longer acts like it is running a for-loop. For example, the following function:

from torch.func import vmap

def f(x, list):
    list.pop()
    print("hello!")
    return x.sum(0)

x = torch.randn(3, 1)
lst = [0, 1, 2, 3]
result = vmap(f, in_dims=(0, None))(x, lst)

will print “hello!” once and pop only one element from lst.

vmap() executes f a single time, so all side effects only happen once.

This is a consequence of how vmap is implemented. torch.func has a special, internal BatchedTensor class. vmap(f)(*inputs) takes all Tensor inputs, turns them into BatchedTensors, and calls f(*batched_tensor_inputs). BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized) behavior for each PyTorch operator.
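The following sketch (using a hypothetical call counter) illustrates that the function passed to vmap() runs only once, on batched inputs, rather than once per example:

import torch
from torch.func import vmap

num_calls = 0

def f(x):
    global num_calls
    num_calls += 1
    return x.sum(0)

x = torch.randn(3, 1)
vmap(f)(x)
assert num_calls == 1  # f was called a single time, on a BatchedTensor input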

Mutation: in-place PyTorch Operations#

You might be here due to receiving an error about vmap-incompatible in-place operations. vmap() will raise an error if it encounters an unsupported PyTorch in-place operation and will succeed otherwise. Unsupported operations are those that would cause a Tensor with more elements to be written to a Tensor with fewer elements. Here’s an example of how this can occur:

def f(x, y):
    x.add_(y)
    return x

x = torch.randn(1)
y = torch.randn(3, 1)  # When vmapped over, looks like it has shape [1]

# Raises an error because `x` has fewer elements than `y`.
vmap(f, in_dims=(None, 0))(x, y)

x is a Tensor with one element and y is a Tensor with three elements. x + y has three elements (due to broadcasting), but attempting to write those three elements back into x, which has only a single element, raises an error.

There is no problem if the Tensor being written to is batched under vmap() (i.e. it is being vmapped over).

def f(x, y):
    x.add_(y)
    return x

x = torch.randn(3, 1)
y = torch.randn(3, 1)
expected = x + y

# Does not raise an error because x is being vmapped over.
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)

One common fix for this is to replace calls to factory functions with their “new_*” equivalents; for example, replace torch.zeros() with Tensor.new_zeros().

To see why this helps, consider the following.

def diag_embed(vec):
    assert vec.dim() == 1
    result = torch.zeros(vec.shape[0], vec.shape[0])
    result.diagonal().copy_(vec)
    return result

vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])

# RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ...
vmap(diag_embed)(vecs)

Inside of vmap(), result is a Tensor of shape [3, 3]. However, although vec looks like it has shape [3], vec actually has underlying shape [2, 3]. It is not possible to copy vec into result.diagonal(), which has shape [3], because vec has too many elements.

def diag_embed(vec):
    assert vec.dim() == 1
    result = vec.new_zeros(vec.shape[0], vec.shape[0])
    result.diagonal().copy_(vec)
    return result

vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
vmap(diag_embed)(vecs)

Replacing torch.zeros() with Tensor.new_zeros() makes it so that result has an underlying Tensor of shape [2, 3, 3], so it is now possible to copy vec, which has underlying shape [2, 3], into result.diagonal().

Mutation: out= PyTorch Operations#

vmap() doesn’t support the out= keyword argument in PyTorch operations. It will error out gracefully if it encounters that in your code.

This is not a fundamental limitation; we could theoretically support this in the future, but we have chosen not to for now.
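For example, the out= variant below would trip this check, while the equivalent out-of-place call works fine (a sketch with hypothetical helper functions):

import torch
from torch.func import vmap

def uses_out(x):
    buf = torch.empty(3)
    return torch.sin(x, out=buf)  # vmap errors on the out= argument

def out_of_place(x):
    return torch.sin(x)  # the out-of-place version works

x = torch.randn(5, 3)
vmap(out_of_place)(x)
# vmap(uses_out)(x)  # would raise an error about out= not being supported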

Data-dependent Python control flow#

We don’t yet support vmap over data-dependent control flow. Data-dependent control flow is when the condition of an if-statement, while-loop, or for-loop is a Tensor that is being vmap’ed over. For example, the following will raise an error message:

def relu(x):
    if x > 0:
        return x
    return 0

x = torch.randn(3)
vmap(relu)(x)
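A common workaround, when it applies, is to express the branch with tensor operations instead of Python control flow; for example, this relu can be written with torch.where (a sketch):

import torch
from torch.func import vmap

def relu(x):
    # torch.where evaluates both branches elementwise, so there is no
    # data-dependent Python branch for vmap to trace through
    return torch.where(x > 0, x, torch.zeros_like(x))

x = torch.randn(3)
vmap(relu)(x)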

However, any control flow that is not dependent on the values in vmap’ed tensors will work:

def custom_dot(x):
    if x.dim() == 1:
        return torch.dot(x, x)
    return (x * x).sum()

x = torch.randn(3)
vmap(custom_dot)(x)

JAX supports transforming over data-dependent control flow using special control flow operators (e.g. jax.lax.cond, jax.lax.while_loop). We’re investigating adding equivalents of those to PyTorch.

Data-dependent operations (.item())#

We do not (and will not) support vmap over a user-defined function that calls .item() on a Tensor. For example, the following will raise an error message:

def f(x):
    return x.item()

x = torch.randn(3)
vmap(f)(x)

Please try to rewrite your code to not use .item() calls.
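For instance, if .item() was only used to pull a scalar out for further computation, the computation can usually stay on Tensors instead (a hypothetical sketch):

import torch
from torch.func import vmap

def f(x):
    return x * 2.0  # instead of x.item() * 2.0, keep the computation on Tensors

x = torch.randn(3)
vmap(f)(x)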

You may also encounter an error message about using .item() but you might not have used it. In those cases, it is possible that PyTorch internally is calling .item(); please file an issue on GitHub and we’ll fix PyTorch internals.

Dynamic shape operations (nonzero and friends)#

vmap(f) requires that f applied to every “example” in your input returns a Tensor with the same shape. Operations such as torch.nonzero and torch.is_nonzero are not supported and will error as a result.

To see why, consider the following example:

xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
vmap(torch.nonzero)(xs)

torch.nonzero(xs[0]) returns a Tensor of shape [2, 1], but torch.nonzero(xs[1]) returns a Tensor of shape [1, 1]. We are unable to construct a single Tensor as an output; the output would need to be a ragged Tensor (and PyTorch does not yet have the concept of a ragged Tensor).
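If the downstream computation only needs to know which entries are nonzero, a fixed-shape formulation such as a boolean mask avoids the ragged output (a sketch):

import torch
from torch.func import vmap

xs = torch.tensor([[0, 1, 2], [0, 0, 3]])

# Every example produces a mask of shape [3], so vmap can stack the results.
masks = vmap(lambda x: x != 0)(xs)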

Randomness#

The user’s intention when calling a random operation can be unclear. Specifically, some users may want the random behavior to be the same across batches while others may want it to differ across batches. To address this, vmap takes a randomness flag.

The flag can only be passed to vmap and can take on 3 values, “error,” “different,” or “same,” defaulting to “error”. Under “error” mode, any call to a random function will produce an error asking the user to use one of the other two flags based on their use case.

Under “different” randomness, elements in a batch produce different random values. For instance,

def add_noise(x):
    y = torch.randn(())  # y will be different across the batch
    return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="different")(x)  # we get 3 different values

Under “same” randomness, elements in a batch produce the same random values. For instance,

def add_noise(x):
    y = torch.randn(())  # y will be the same across the batch
    return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times

Warning

Our system only determines the randomness behavior of PyTorch operators and cannot control the behavior of other libraries, like numpy. This is similar to the limitations of JAX’s solutions.

Note

Multiple vmap calls using either type of supported randomness will not produce the same results. Like with standard PyTorch, a user can get randomness reproducibility through either using torch.manual_seed() outside of vmap or by using generators.
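For example, re-seeding outside of vmap gives reproducible results across separate calls (a minimal sketch):

import torch
from torch.func import vmap

def add_noise(x):
    return x + torch.randn(())

x = torch.ones(3)

torch.manual_seed(0)
first = vmap(add_noise, randomness="different")(x)

torch.manual_seed(0)
second = vmap(add_noise, randomness="different")(x)

assert torch.allclose(first, second)  # same seed outside vmap -> same results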

Note

Finally, our randomness differs from JAX because we aren’t using a stateless PRNG, in part because PyTorch doesn’t have full support for a stateless PRNG. Instead, we’ve introduced a flag system to allow for the most common forms of randomness that we see. If your use case does not fit these forms of randomness, please file an issue.