torch.func.grad_and_value
- torch.func.grad_and_value(func, argnums=0, has_aux=False)
Returns a function to compute a tuple of the gradient and the primal (forward) computation.
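A minimal sketch of the default behaviour, assuming PyTorch 2.x with `torch.func` available. The function `f` is a hypothetical example. It reduces with `sum()` so that it returns a single-element tensor, and the transformed function returns the gradient with respect to the first argument together with the primal output.

```python
import torch
from torch.func import grad_and_value

def f(x):
    # Hypothetical example function: reduce to a single-element tensor,
    # as grad_and_value requires.
    return torch.sin(x).sum()

x = torch.randn(5)

# grad_and_value(f) is a new function; calling it evaluates f once
# and returns (gradient w.r.t. argument 0, value of f(x)).
grad, value = grad_and_value(f)(x)

# d/dx sum(sin(x)) = cos(x); value is the ordinary forward result.
print(torch.allclose(grad, torch.cos(x)))
print(torch.allclose(value, f(x)))
```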
- Parameters
func (Callable) – A Python function that takes one or more arguments. It must return a single-element Tensor. If has_aux is True, the function can instead return a tuple of a single-element Tensor and other auxiliary objects: (output, aux).
argnums (int or Tuple[int]) – Specifies the arguments to compute gradients with respect to. argnums can be a single integer or a tuple of integers. Default: 0.
has_aux (bool) – Flag indicating that func returns a tensor and other auxiliary objects: (output, aux). Default: False.
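The sketch below shows how `argnums` changes the shape of the gradient part of the result, assuming PyTorch 2.x. `weighted_sum` is a hypothetical function. With an integer `argnums` the gradient is a single tensor. With a tuple of integers it is a tuple of gradients, one per listed argument.

```python
import torch
from torch.func import grad_and_value

def weighted_sum(x, w):
    # Hypothetical example: dot product of x and w (a single-element tensor).
    return (x * w).sum()

x = torch.randn(3)
w = torch.randn(3)

# argnums=(0, 1): gradients w.r.t. both positional arguments, as a tuple.
(grad_x, grad_w), value = grad_and_value(weighted_sum, argnums=(0, 1))(x, w)

# argnums=1: a single gradient tensor w.r.t. the second argument only.
grad_w_only, _ = grad_and_value(weighted_sum, argnums=1)(x, w)

# For f(x, w) = sum(x * w): df/dx = w and df/dw = x.
print(torch.allclose(grad_x, w), torch.allclose(grad_w, x))
```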
- Returns
Function to compute a tuple of gradients with respect to its inputs and the forward computation. By default, the output of the function is a tuple of the gradient tensor(s) with respect to the first argument and the primal computation. If has_aux is True, the function returns a tuple of the gradients and a tuple of the forward computation with its auxiliary objects: (gradients, (output, aux)). If argnums is a tuple of integers, the gradient part of the result is a tuple of the gradients with respect to each argnums value, paired with the forward computation.
- Return type
Callable
See grad() for examples.
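A minimal sketch of the has_aux=True return structure, assuming PyTorch 2.x. `loss_with_aux` is a hypothetical function. Only the first element of its returned tuple is differentiated, and the auxiliary tensor is passed through unchanged inside the nested (output, aux) tuple.

```python
import torch
from torch.func import grad_and_value

def loss_with_aux(x):
    y = x ** 2
    # First element is differentiated; y is returned as auxiliary data.
    return y.sum(), y

x = torch.tensor([1.0, 2.0, 3.0])

# With has_aux=True the result is (gradients, (output, aux)).
grad, (value, aux) = grad_and_value(loss_with_aux, has_aux=True)(x)

# grad is d/dx sum(x**2) = 2 * x; value is sum(x**2); aux is x**2.
print(grad, value, aux)
```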