jax.value_and_grad
- jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())
Create a function that evaluates both fun and the gradient of fun.
- Parameters:
fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer-valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
reduce_axes (Sequence[AxisName])
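A minimal sketch of the fun, has_aux, and holomorphic parameters in use. The functions sum_of_squares, loss_with_aux, and square are hypothetical examples, not part of JAX; the values noted in the comments follow from elementary calculus on these toy functions.

```python
import jax
import jax.numpy as jnp

# fun must return a scalar: jnp.sum yields an array with shape (),
# which counts as a scalar. A function returning shape (1,) would not.
def sum_of_squares(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
value, grad = jax.value_and_grad(sum_of_squares)(x)
# value = 1 + 4 + 9 = 14.0, grad = 2 * x = [2., 4., 6.]

# has_aux=True: fun returns (output_to_differentiate, auxiliary_data).
# Only the first element is differentiated; the second is passed through.
def loss_with_aux(w, x):
    pred = w * x  # hypothetical "prediction" returned as auxiliary data
    return jnp.sum(pred ** 2), pred

(loss, pred), w_grad = jax.value_and_grad(loss_with_aux, has_aux=True)(2.0, x)
# loss = sum((2x)^2) = 56.0, d(loss)/dw = 2 * w * sum(x^2) = 56.0

# holomorphic=True: inputs and outputs must have a complex dtype.
def square(z):
    return z ** 2

z = jnp.complex64(3.0 + 1.0j)
zval, zgrad = jax.value_and_grad(square, holomorphic=True)(z)
# zval = (3+1j)^2 = 8+6j, zgrad = 2z = 6+2j
```

Note that only argument 0 (w) is differentiated in loss_with_aux, since argnums defaults to 0.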
- Returns:
A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned.
- Return type:
Callable[…, tuple[Any, Any]]
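The sketch below illustrates how the shape of the returned gradient depends on argnums. The function weighted_loss and its inputs are hypothetical; the expected values in the comments are worked out by hand for this toy loss.

```python
import jax
import jax.numpy as jnp

# Hypothetical loss with three positional arguments.
def weighted_loss(w, b, x):
    return jnp.sum((w * x + b) ** 2)

w = jnp.array([1.0, -1.0])
b = 1.0
x = jnp.array([2.0, 3.0])

# argnums is an integer (default 0): a single gradient shaped like w.
value, dw = jax.value_and_grad(weighted_loss)(w, b, x)
# residuals w*x + b = [3., -2.], value = 9 + 4 = 13.0
# dw = 2 * residual * x = [12., -12.]

# argnums is a sequence: the gradient is a tuple with one entry per listed
# argument, each shaped like the corresponding argument.
value2, grads = jax.value_and_grad(weighted_loss, argnums=(0, 1))(w, b, x)
dw2, db2 = grads
# db2 = 2 * sum(residual) = 2.0, a scalar because b is a scalar
```

The value is the same in both calls; only the structure of the gradient changes with argnums.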
