jax.grad
- jax.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())
Creates a function that evaluates the gradient of fun.
- Parameters:
fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.).
argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
reduce_axes (Sequence[AxisName])
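The argnums and has_aux parameters described above can be illustrated with a minimal sketch. The functions loss and loss_with_aux and the arrays w, b, and x are hypothetical names chosen for this example; they are not part of the JAX API.

```python
import jax
import jax.numpy as jnp


def loss(w, b, x):
    # Hypothetical scalar-valued function of three array arguments.
    return jnp.sum((w * x + b) ** 2)


w = jnp.array([1.0, 2.0])
b = jnp.array([0.5, -0.5])
x = jnp.array([3.0, 4.0])

# Default argnums=0: differentiate with respect to the first argument (w) only.
dw = jax.grad(loss)(w, b, x)

# argnums=(0, 1): the result is a tuple of gradients, one per listed argument.
dw2, db = jax.grad(loss, argnums=(0, 1))(w, b, x)


def loss_with_aux(w, b, x):
    # Returns a pair: (scalar to differentiate, auxiliary data).
    pred = w * x + b
    return jnp.sum(pred ** 2), pred


# has_aux=True: the result is a pair (gradient, auxiliary_data).
dw3, pred = jax.grad(loss_with_aux, has_aux=True)(w, b, x)
```

Here only the first element of the pair returned by loss_with_aux is differentiated; pred is passed through unchanged as auxiliary data.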
- Returns:
A function with the same arguments as fun, that evaluates the gradient of fun. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a pair of (gradient, auxiliary_data) is returned.
- Return type:
Callable
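As a rough illustration of the return structure, the sketch below differentiates with respect to a dict of arrays and checks that the gradient mirrors the input. The function energy and the keys "scale" and "offset" are assumptions made for this example only.

```python
import jax
import jax.numpy as jnp


def energy(params, x):
    # params is a hypothetical dict of arrays (a standard Python container).
    return jnp.sum(params["scale"] * x ** 2) + params["offset"]


params = {"scale": jnp.ones((2, 3)), "offset": jnp.array(1.0)}
x = jnp.full((2, 3), 2.0)

# Integer argnums (default 0): the gradient has the same structure as params,
# and each entry has the same shape as the corresponding input array.
g = jax.grad(energy)(params, x)
print(g["scale"].shape, g["offset"].shape)

# Tuple argnums: the gradient is a tuple with one entry per listed argument.
both = jax.grad(energy, argnums=(0, 1))(params, x)
print(type(both), both[1].shape)
```

The gradient with respect to params is itself a dict with the keys "scale" and "offset", so it can be combined entry by entry with the original parameters.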
For example:
>>> import jax
>>>
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
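The scalar-output requirement can be seen by applying tanh to a one-element array instead of a Python float. The sketch below assumes that jax.grad signals a non-scalar output with a TypeError, and uses the hypothetical helper name vector_out; reducing the output with jnp.sum is one way to obtain a shape-() result.

```python
import jax
import jax.numpy as jnp


def vector_out(x):
    # For an input of shape (1,), the output also has shape (1,), not shape ().
    return jnp.tanh(x)


try:
    jax.grad(vector_out)(jnp.array([0.2]))
    raised = False
except TypeError:
    # Assumption: a non-scalar output is reported as a TypeError.
    raised = True

# Summing reduces the output to shape (), so grad accepts the function.
g = jax.grad(lambda x: jnp.sum(jnp.tanh(x)))(jnp.array([0.2]))

# Compare with the closed-form derivative 1 - tanh(x)^2.
expected = 1.0 - jnp.tanh(0.2) ** 2
print(raised, g, expected)
```

The gradient g has shape (1,), matching the input array, and its single entry agrees with the value printed by the grad_tanh example.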
