torch.func.vjp#
- torch.func.vjp(func, *primals, has_aux=False)
Standing for the vector-Jacobian product, returns a tuple containing the results of func applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of func with respect to primals times cotangents.
- Parameters:
func (Callable) – A Python function that takes one or more arguments. Must return one or more Tensors.
primals (Tensors) – Positional arguments to func that must all be Tensors. The returned function will also be computing the derivative with respect to these arguments.
has_aux (bool) – Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.
- Returns:
Returns a (output, vjp_fn) tuple containing the output of func applied to primals and a function that computes the vjp of func with respect to all primals using the cotangents passed to the returned function. If has_aux is True, then instead returns a (output, vjp_fn, aux) tuple. The returned vjp_fn function will return a tuple of each VJP.
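As a minimal sketch of the has_aux=True case (the function f below is illustrative, not from the docs above): the auxiliary value is returned unchanged and is not differentiated.

>>> x = torch.randn([5])
>>> def f(x):
>>>     loss = x.sin().sum()
>>>     return loss, {"intermediate": x.sin()}
>>>
>>> output, vjpfunc, aux = torch.func.vjp(f, x, has_aux=True)
>>> grads = vjpfunc(torch.tensor(1.0))
>>> assert torch.allclose(grads[0], x.cos())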
When used in simple cases, vjp() behaves the same as grad()

>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.0))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))
However, vjp() can support functions with multiple outputs by passing in the cotangents for each of the outputs

>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
vjp() can even support outputs being Python structs

>>> x = torch.randn([5])
>>> f = lambda x: {"first": x.sin(), "second": x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {"first": torch.ones([5]), "second": torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
The function returned by vjp() will compute the partials with respect to each of the primals

>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
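For context, the same vector-Jacobian products can be reproduced with standard autograd by passing the cotangents as grad_outputs; a minimal, self-contained sketch (not part of the docs above):

>>> x = torch.randn([5, 4], requires_grad=True)
>>> y = torch.randn([4, 5], requires_grad=True)
>>> cotangents = torch.randn([5, 5])
>>> out = torch.matmul(x, y)
>>> autograd_vjps = torch.autograd.grad(out, (x, y), grad_outputs=cotangents)
>>> _, vjpfunc = torch.func.vjp(torch.matmul, x.detach(), y.detach())
>>> func_vjps = vjpfunc(cotangents)
>>> assert torch.allclose(autograd_vjps[0], func_vjps[0])
>>> assert torch.allclose(autograd_vjps[1], func_vjps[1])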
primals are the positional arguments for f. All kwargs use their default value

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>     return x * scale
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.0))
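If a non-default keyword argument is needed, one option is to bind it before passing the function to vjp(); a minimal sketch using functools.partial (an assumed usage pattern, not shown in the docs above):

>>> from functools import partial
>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>     return x * scale
>>>
>>> (_, vjpfunc) = torch.func.vjp(partial(f, scale=2.0), x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 2.0))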
Note
Using PyTorch torch.no_grad together with vjp.

Case 1: Using torch.no_grad inside a function:

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

In this case, vjp(f)(x) will respect the inner torch.no_grad.

Case 2: Using vjp inside torch.no_grad context manager:

>>> with torch.no_grad():
>>>     vjp(f)(x)

In this case, vjp will respect the inner torch.no_grad, but not the outer one. This is because vjp is a "function transform": its result should not depend on the result of a context manager outside of f.
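A minimal sketch combining both cases, based on the behavior described in this note and reusing the f from Case 1: the outer torch.no_grad is ignored by the transform, while the inner one makes c a constant.

>>> x = torch.randn([5])
>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c
>>>
>>> with torch.no_grad():
>>>     _, vjpfunc = torch.func.vjp(f, x)
>>>     vjps = vjpfunc(torch.ones_like(x))
>>> # c is treated as a constant, so d(x - c)/dx is 1 everywhere
>>> assert torch.allclose(vjps[0], torch.ones_like(x))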