torch.func.hessian
- torch.func.hessian(func, argnums=0)
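Computes the Hessian of func with respect to the arg(s) at index argnums via a forward-over-reverse strategy.

The forward-over-reverse strategy (composing jacfwd(jacrev(func))) is a good default for performance. It is also possible to compute Hessians through other compositions of jacfwd() and jacrev(), such as jacfwd(jacfwd(func)) or jacrev(jacrev(func)).
- Parameters
  - func (function) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors.
  - argnums (int or Tuple[int]) – Optional. An integer or tuple of integers specifying which arguments to compute the Hessian with respect to. Default: 0.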
- Returns
Returns a function that takes in the same inputs as func and returns the Hessian of func with respect to the arg(s) at argnums.
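When argnums is a tuple, the Hessian is taken with respect to several arguments at once. The sketch below assumes a hypothetical scalar-valued function f(x, y) = sum(x**2) * sum(y), chosen because its second derivatives are easy to write by hand. In this sketch the result comes back as a nested tuple: h[i][j] is the block of second derivatives for arguments i and j, with shape arg_i.shape + arg_j.shape.

```python
import torch
from torch.func import hessian

def f(x, y):
    # Hypothetical scalar-valued function of two tensor arguments.
    return (x ** 2).sum() * y.sum()

x = torch.randn(3)
y = torch.randn(2)

# Second derivatives with respect to both arguments.
h = hessian(f, argnums=(0, 1))(x, y)

# Unpack the nested tuple into its four blocks.
h_xx, h_xy = h[0]  # shapes (3, 3) and (3, 2)
h_yx, h_yy = h[1]  # shapes (2, 3) and (2, 2)

# Hand-derived blocks for this particular f:
#   d2f/dx2  = 2 * sum(y) * I
#   d2f/dxdy = 2 * x_i for every column j
#   d2f/dy2  = 0
print(h_xx.shape, h_xy.shape, h_yx.shape, h_yy.shape)
```

With the default argnums=0, only the (3, 3) block h_xx would be returned, as a plain Tensor rather than a nested tuple.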
Note
You may see this API error out with "forward-mode AD not implemented for operator X". If so, please file a bug report and we will prioritize it. An alternative is to use jacrev(jacrev(func)), which has better operator coverage.

A basic usage with an R^N -> R^1 function gives an N x N Hessian:
>>> import torch
>>> from torch.func import hessian
>>> def f(x):
...     return x.sin().sum()
...
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))
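The compositions of jacfwd() and jacrev() mentioned above should agree on well-behaved functions. A minimal sketch, assuming a hypothetical quadratic form f(x) = x^T A x whose Hessian is A + A^T for every x, so each result can be checked against a known answer. The matrix A and the function quad are illustrative only.

```python
import torch
from torch.func import hessian, jacfwd, jacrev

torch.manual_seed(0)
n = 4
# float64 keeps the comparisons between strategies tight.
A = torch.randn(n, n, dtype=torch.float64)

def quad(x):
    # Hypothetical R^n -> R^1 function: f(x) = x^T A x.
    return x @ A @ x

x = torch.randn(n, dtype=torch.float64)
expected = A + A.T  # analytic Hessian of x^T A x

h_default = hessian(quad)(x)          # forward-over-reverse (the default)
h_fwd_rev = jacfwd(jacrev(quad))(x)   # the same composition, spelled out
h_rev_rev = jacrev(jacrev(quad))(x)   # reverse-over-reverse fallback
h_fwd_fwd = jacfwd(jacfwd(quad))(x)   # forward-over-forward

for h in (h_default, h_fwd_rev, h_rev_rev, h_fwd_fwd):
    assert h.shape == (n, n)
    assert torch.allclose(h, expected)
```

Here jacrev(jacrev(quad)) is the composition to fall back on if forward-mode AD is missing for some operator in func; for this function all four strategies produce the same N x N matrix.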