
torch.func.hessian

torch.func.hessian(func, argnums=0)

Computes the Hessian of func with respect to the arg(s) at index argnums via a forward-over-reverse strategy.

The forward-over-reverse strategy (composing jacfwd(jacrev(func))) is a good default for performance. It is also possible to compute Hessians through other compositions of jacfwd() and jacrev(), such as jacfwd(jacfwd(func)) or jacrev(jacrev(func)).
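As a quick sanity check of that statement, the sketch below computes the Hessian of a simple scalar function three ways and confirms they agree. It assumes a PyTorch version that ships torch.func; the function f and the float64 input are illustrative choices only.

```python
import torch
from torch.func import hessian, jacfwd, jacrev

def f(x):
    # Scalar-valued function R^N -> R; its Hessian is diag(-sin(x)).
    return x.sin().sum()

x = torch.randn(5, dtype=torch.float64)

h_default = hessian(f)(x)         # forward-over-reverse (the default)
h_fwd_fwd = jacfwd(jacfwd(f))(x)  # forward-over-forward
h_rev_rev = jacrev(jacrev(f))(x)  # reverse-over-reverse

# All three compositions should produce the same N x N matrix.
for h in (h_default, h_fwd_fwd, h_rev_rev):
    assert h.shape == (5, 5)
    assert torch.allclose(h, torch.diag(-x.sin()))
```

The compositions differ in speed and operator coverage, not in the value they return.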

Parameters
  • func (function) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors.

  • argnums (int or Tuple[int]) – Optional, integer or tuple of integers, saying which arguments to get the Hessian with respect to. Default: 0.
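A sketch of passing argnums as a tuple, using a hypothetical two-argument function f. The result is expected to be a nested tuple of blocks, where hess[i][j] holds second derivatives taken with respect to argument i and then argument j (this assumes the tuple-of-tuples convention used by jacfwd and jacrev; the example is symmetric, so the checks do not depend on that ordering).

```python
import torch
from torch.func import hessian

def f(x, y):
    # Hypothetical scalar function of two tensors; each term mixes x and y.
    return (x.sin() * y.cos()).sum()

x = torch.randn(3, dtype=torch.float64)
y = torch.randn(3, dtype=torch.float64)

# Differentiate with respect to both arguments.
hess = hessian(f, argnums=(0, 1))(x, y)

h_xx, h_xy = hess[0]
h_yx, h_yy = hess[1]

# f is a sum of elementwise terms, so every block is diagonal.
assert torch.allclose(h_xx, torch.diag(-x.sin() * y.cos()))
assert torch.allclose(h_xy, torch.diag(-x.cos() * y.sin()))
assert torch.allclose(h_yx, torch.diag(-x.cos() * y.sin()))
assert torch.allclose(h_yy, torch.diag(-x.sin() * y.cos()))
```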

Returns

Returns a function that takes in the same inputs as func and returns the Hessian of func with respect to the arg(s) at argnums.
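When func returns a non-scalar Tensor, the output's dimensions appear in front of the two input dimensions, so the shape is output.shape + x.shape + x.shape. The sketch below illustrates this with a hypothetical R^3 -> R^2 function g.

```python
import torch
from torch.func import hessian

def g(x):
    # Hypothetical vector-valued function R^3 -> R^2.
    return torch.stack([(x ** 2).sum(), (x ** 3).sum()])

x = torch.tensor([1.0, 2.0, 3.0])
h = hessian(g)(x)

# One 3 x 3 Hessian per output element.
assert h.shape == (2, 3, 3)
# sum(x**2) has Hessian 2 * I; sum(x**3) has Hessian diag(6 * x).
assert torch.allclose(h[0], 2 * torch.eye(3))
assert torch.allclose(h[1], torch.diag(6 * x))
```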

Note

You may see this API error out with “forward-mode AD not implemented for operator X”. If so, please file a bug report and we will prioritize it. An alternative is to use jacrev(jacrev(func)), which has better operator coverage.
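One way to act on this note is a small wrapper that tries hessian first and falls back to jacrev(jacrev(func)) if the call fails. The helper name hessian_with_fallback is hypothetical, and the sketch assumes the missing-forward-AD error surfaces as a RuntimeError (Python's NotImplementedError is a subclass of RuntimeError, so both are caught).

```python
import torch
from torch.func import hessian, jacrev

def hessian_with_fallback(func, *args):
    # Hypothetical helper, not part of torch.func.
    try:
        # Default forward-over-reverse strategy; errors surface at call time.
        return hessian(func)(*args)
    except RuntimeError:
        # Assumed failure mode: an op lacks a forward-mode AD formula.
        # Reverse-over-reverse avoids forward-mode AD entirely.
        return jacrev(jacrev(func))(*args)

def f(x):
    # d2/dx2 of x * exp(x) is (2 + x) * exp(x), elementwise.
    return (x.exp() * x).sum()

x = torch.randn(4, dtype=torch.float64)
h = hessian_with_fallback(f, x)
assert torch.allclose(h, torch.diag((2 + x) * x.exp()))
```

Catching RuntimeError broadly can also hide unrelated bugs, so a stricter version might inspect the error message before falling back.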

A basic usage with an R^N -> R^1 function gives an N x N Hessian:

>>> import torch
>>> from torch.func import hessian
>>> def f(x):
...     return x.sin().sum()
...
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))