
torch.func.vjp

torch.func.vjp(func, *primals, has_aux=False)

Standing for the vector-Jacobian product, returns a tuple containing the results of func applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of func with respect to primals times cotangents.

Parameters:
  • func (Callable) – A Python function that takes one or more arguments. Must return one or more Tensors.

  • primals (Tensors) – Positional arguments to func that must all be Tensors. The returned function will compute the derivative with respect to these arguments.

  • has_aux (bool) – Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.

Returns:

Returns a (output, vjp_fn) tuple containing the output of func applied to primals and a function that computes the vjp of func with respect to all primals using the cotangents passed to the returned function. If has_aux is True, then instead returns a (output, vjp_fn, aux) tuple. The returned vjp_fn function will return a tuple of each VJP.
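
For instance, a minimal sketch of the has_aux=True case (the function f and its auxiliary value below are illustrative, not part of the API):

>>> x = torch.randn([5])
>>> def f(x):
>>>     y = x.sin()
>>>     return y.sum(), y  # second element is auxiliary and is not differentiated
>>>
>>> (output, vjpfunc, aux) = torch.func.vjp(f, x, has_aux=True)
>>> vjps = vjpfunc(torch.tensor(1.0))
>>> assert torch.allclose(vjps[0], x.cos())
>>> assert torch.allclose(aux, x.sin())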

When used in simple cases, vjp() behaves the same as grad():

>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.0))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))

However, vjp() can support functions with multiple outputs by passing in the cotangents for each of the outputs:

>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

vjp() can even support outputs being Python structs:

>>> x = torch.randn([5])
>>> f = lambda x: {"first": x.sin(), "second": x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {"first": torch.ones([5]), "second": torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

The function returned by vjp() will compute the partials with respect to each of the primals:

>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))

primals are the positional arguments for f. All kwargs use their default values:

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>     return x * scale
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.0))
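
If a derivative is needed at a non-default kwarg value, one option is to bind the kwarg before calling vjp(), for example with a lambda (or functools.partial); a minimal sketch:

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>     return x * scale
>>>
>>> # bind scale=2. so that vjp only sees the single Tensor argument
>>> (_, vjpfunc) = torch.func.vjp(lambda t: f(t, scale=2.), x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 2.0))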

Note

Using PyTorch torch.no_grad together with vjp.

Case 1: Using torch.no_grad inside a function:

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

In this case, vjp(f)(x) will respect the inner torch.no_grad.

Case 2: Using vjp inside the torch.no_grad context manager:

>>> with torch.no_grad():
>>>     vjp(f)(x)

In this case, vjp will respect the inner torch.no_grad, but not the outer one. This is because vjp is a "function transform": its result should not depend on the result of a context manager outside of f.
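
Putting the two cases together, a minimal sketch of the behavior described above: because of the inner torch.no_grad, c is treated as a constant, so the computed gradient is 1 everywhere even when vjp is called under an outer torch.no_grad.

>>> x = torch.randn([5])
>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c
>>>
>>> with torch.no_grad():
>>>     (_, vjpfunc) = torch.func.vjp(f, x)
>>>     vjps = vjpfunc(torch.ones_like(x))
>>> # the inner no_grad makes c a constant, so d(x - c)/dx == 1
>>> assert torch.allclose(vjps[0], torch.ones_like(x))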