torch.func.jacfwd
torch.func.jacfwd(func, argnums=0, has_aux=False, *, randomness='error')
Computes the Jacobian of func with respect to the arg(s) at index argnums using forward-mode autodiff.

Parameters
    func (function) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors.
    argnums (int or Tuple[int]) – Optional, integer or tuple of integers, saying which arguments to get the Jacobian with respect to. Default: 0.
    has_aux (bool) – Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is auxiliary objects that will not be differentiated. Default: False.
    randomness (str) – Flag indicating what type of randomness to use. See vmap() for more detail. Allowed: "different", "same", "error". Default: "error".
Returns
    Returns a function that takes in the same inputs as func and returns the Jacobian of func with respect to the arg(s) at argnums. If has_aux is True, then the returned function instead returns a (jacobian, aux) tuple where jacobian is the Jacobian and aux is auxiliary objects returned by func.
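To make the shape semantics of the returned Jacobian concrete, here is a minimal sketch (the function f below is illustrative, not part of the API): for a function mapping a 3-element tensor to a 2-element tensor, the Jacobian's leading dimensions come from the output and its trailing dimensions from the input.

```python
import torch
from torch.func import jacfwd

# Illustrative function mapping a 3-element tensor to a 2-element tensor.
def f(x):
    return torch.stack([x[0] * x[1], x[2] ** 2])

x = torch.randn(3)
jacobian = jacfwd(f)(x)
# Output dimensions come first, then input dimensions: shape (2, 3).
assert jacobian.shape == (2, 3)
```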
Note
You may see this API error out with "forward-mode AD not implemented for operator X". If so, please file a bug report and we will prioritize it. An alternative is to use jacrev(), which has better operator coverage.

A basic usage with a pointwise, unary operation will give a diagonal array as the Jacobian:
>>> from torch.func import jacfwd
>>> x = torch.randn(5)
>>> jacobian = jacfwd(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)
jacfwd() can be composed with vmap to produce batched Jacobians:

>>> from torch.func import jacfwd, vmap
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacfwd(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
If you would like to compute the output of the function as well as the Jacobian of the function, use the has_aux flag to return the output as an auxiliary object:

>>> from torch.func import jacfwd
>>> x = torch.randn(5)
>>>
>>> def f(x):
>>>     return x.sin()
>>>
>>> def g(x):
>>>     result = f(x)
>>>     return result, result
>>>
>>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x)
>>> assert torch.allclose(f_x, f(x))
Additionally, jacfwd() can be composed with itself or with jacrev() to produce Hessians:

>>> from torch.func import jacfwd, jacrev
>>> def f(x):
>>>     return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hessian = jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))
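Forward-over-forward composition works as well; a minimal sketch using the same toy function (this assumes forward-mode AD supports every operator involved, and jacfwd(jacrev(f)) above is typically the better-performing default for Hessians):

```python
import torch
from torch.func import jacfwd

def f(x):
    return x.sin().sum()

x = torch.randn(5)
# Composing jacfwd with itself differentiates the forward-mode
# gradient a second time, also yielding the Hessian.
hessian = jacfwd(jacfwd(f))(x)
assert torch.allclose(hessian, torch.diag(-x.sin()))
```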
By default, jacfwd() computes the Jacobian with respect to the first input. However, it can compute the Jacobian with respect to a different argument by using argnums:

>>> from torch.func import jacfwd
>>> def f(x, y):
>>>     return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacfwd(f, argnums=1)(x, y)
>>> expected = torch.diag(2 * y)
>>> assert torch.allclose(jacobian, expected)
Additionally, passing a tuple to argnums will compute the Jacobian with respect to multiple arguments:

>>> from torch.func import jacfwd
>>> def f(x, y):
>>>     return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
>>> expectedX = torch.diag(torch.ones_like(x))
>>> expectedY = torch.diag(2 * y)
>>> assert torch.allclose(jacobian[0], expectedX)
>>> assert torch.allclose(jacobian[1], expectedY)
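A tuple argnums can also be combined with has_aux; the sketch below (using the same toy function, and assuming that when argnums is a tuple the jacobian element of the returned (jacobian, aux) pair is itself a tuple, one entry per requested argument):

```python
import torch
from torch.func import jacfwd

def f(x, y):
    result = x + y ** 2
    return result, result  # (output to differentiate, aux)

x, y = torch.randn(5), torch.randn(5)
# jacfwd returns ((jac_x, jac_y), aux): one Jacobian per argnum,
# plus the undifferentiated auxiliary object.
(jac_x, jac_y), aux = jacfwd(f, argnums=(0, 1), has_aux=True)(x, y)
assert torch.allclose(jac_x, torch.diag(torch.ones_like(x)))
assert torch.allclose(jac_y, torch.diag(2 * y))
assert torch.allclose(aux, x + y ** 2)
```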