
torch.func.grad_and_value

torch.func.grad_and_value(func, argnums=0, has_aux=False)

Returns a function that computes a tuple of the gradient and the primal (forward) computation.
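A minimal sketch of the default behavior, assuming PyTorch 2.x where `torch.func` is available. The function `f` is a hypothetical example; it returns a single-element tensor, so the transformed function yields the gradient with respect to the first argument together with the forward value.

```python
import torch
from torch.func import grad_and_value

def f(x):
    # Hypothetical example: single-element output, sum of sin over all entries
    return x.sin().sum()

x = torch.randn(3)
grad_x, value = grad_and_value(f)(x)

# The gradient of sum(sin(x)) is cos(x); the value is the ordinary forward result f(x)
print(torch.allclose(grad_x, x.cos()), torch.allclose(value, f(x)))
```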

Parameters
  • func (Callable) – A Python function that takes one or more arguments. It must return a single-element Tensor. If has_aux is True, the function can instead return a tuple of a single-element Tensor and other auxiliary objects: (output, aux).

  • argnums (int or Tuple[int]) – Specifies the arguments to compute gradients with respect to. argnums can be a single integer or a tuple of integers. Default: 0.

  • has_aux (bool) – Flag indicating that func returns a tensor and other auxiliary objects: (output, aux). Default: False.
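A minimal sketch of has_aux=True, assuming PyTorch 2.x. The function `loss_with_aux` is hypothetical; it returns (output, aux), where the auxiliary object is an intermediate tensor. Only `output` is differentiated, and `aux` is passed back alongside the forward value.

```python
import torch
from torch.func import grad_and_value

def loss_with_aux(x):
    y = x ** 2
    # Hypothetical example: (output, aux), where output is a single-element tensor
    # and aux is an intermediate tensor we want to inspect
    return y.sum(), y

x = torch.tensor([1.0, 2.0, 3.0])
# With has_aux=True the result is (gradient, (output, aux))
grad_x, (value, aux) = grad_and_value(loss_with_aux, has_aux=True)(x)

# d/dx sum(x^2) = 2x; value = 1 + 4 + 9 = 14; aux holds the squares
print(grad_x, value, aux)
```

The auxiliary tensor is not part of the gradient computation; it is simply returned next to the primal output.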

Returns

Function that computes a tuple of the gradients with respect to its inputs and the forward computation. By default, the output of the function is a tuple of the gradient tensor(s) with respect to the first argument and the primal computation. If has_aux is True, a tuple of the gradients and a tuple of the forward computation with the output auxiliary objects, (output, aux), is returned. If argnums is a tuple of integers, a tuple containing a tuple of the output gradients with respect to each argnums value and the forward computation is returned.
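A minimal sketch of the tuple-argnums return structure, assuming PyTorch 2.x. The function `weighted_sum` is hypothetical; with argnums=(0, 1) the first element of the result is itself a tuple holding one gradient per selected argument.

```python
import torch
from torch.func import grad_and_value

def weighted_sum(x, w):
    # Hypothetical example: single-element output depending on both arguments
    return (x * w).sum()

x = torch.tensor([1.0, 2.0])
w = torch.tensor([3.0, 4.0])
# Result shape: ((grad wrt x, grad wrt w), value)
(grad_x, grad_w), value = grad_and_value(weighted_sum, argnums=(0, 1))(x, w)

# d/dx = w, d/dw = x, value = 1*3 + 2*4 = 11
print(grad_x, grad_w, value)
```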

Return type

Callable

See grad() for examples.
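As a quick comparison with grad(), here is a minimal sketch assuming PyTorch 2.x; the function `f` is hypothetical. Both transforms compute the same gradient, but grad_and_value also returns the forward value, which avoids a separate call to f when both are needed.

```python
import torch
from torch.func import grad, grad_and_value

def f(x):
    # Hypothetical example: single-element output
    return (x ** 3).sum()

x = torch.tensor([1.0, 2.0])
g_only = grad(f)(x)               # gradient only
g, value = grad_and_value(f)(x)   # gradient and forward value

# Both give 3 * x**2; value = 1 + 8 = 9
print(g_only, g, value)
```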