
jax.grad

jax.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())

Creates a function that evaluates the gradient of fun.

Parameters:
  • fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)

  • argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.

  • allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer-valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.

  • reduce_axes (Sequence[AxisName])
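The holomorphic and allow_int flags are easiest to see on tiny functions. The sketch below uses two hypothetical functions, f_complex and scale, which are illustrations and not part of the API. The expected values follow from d(z²)/dz = 2z and d(x·n)/dx = n; the integer argument's gradient is a float0 placeholder rather than a number.

```python
import jax
import jax.numpy as jnp

# holomorphic=True: input and output must both be complex.
def f_complex(z):
    return z ** 2  # complex derivative is 2*z

df = jax.grad(f_complex, holomorphic=True)
dz = df(jnp.complex64(1.0 + 2.0j))  # 2*(1+2j) = 2+4j

# allow_int=True: an integer argument may be listed in argnums.
def scale(x, n):
    return x * n

g = jax.grad(scale, argnums=(0, 1), allow_int=True)
dx, dn = g(3.0, 2)
# dx is d(x*n)/dx = n = 2.0; dn has dtype jax.dtypes.float0,
# the trivial tangent dtype used for integer inputs.
print(dz, dx, dn.dtype)
```

Without allow_int=True, passing the integer argument in argnums raises a TypeError because integer inputs are not inexact.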

Returns:

A function with the same arguments as fun that evaluates the gradient of fun. If argnums is an integer, then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True, then a pair of (gradient, auxiliary_data) is returned.
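The return conventions above can be checked on a small least-squares-style loss. In this sketch, loss and loss_with_aux are hypothetical functions chosen so the gradients are easy to verify by hand: with pred = x·w + b = 11.5, the gradient with respect to w is 2·pred·x = [69, 92] and with respect to b is 2·pred = 23.

```python
import jax
import jax.numpy as jnp

def loss(w, b, x):
    pred = jnp.dot(x, w) + b  # scalar, since x and w are 1-D
    return jnp.sum(pred ** 2)

w = jnp.array([1.0, 2.0])
b = 0.5
x = jnp.array([3.0, 4.0])

# Integer argnums (default 0): gradient has the same shape as w.
dw_only = jax.grad(loss)(w, b, x)

# Tuple argnums: a tuple of gradients, one per selected argument.
dw, db = jax.grad(loss, argnums=(0, 1))(w, b, x)

# has_aux=True: fun returns (scalar, aux) and grad returns (gradient, aux).
def loss_with_aux(w, b, x):
    pred = jnp.dot(x, w) + b
    return jnp.sum(pred ** 2), {"pred": pred}

dw_aux, aux = jax.grad(loss_with_aux, has_aux=True)(w, b, x)
print(dw_only, db, aux["pred"])
```

The auxiliary output is passed through unchanged and is not differentiated, which makes it a convenient way to return intermediate values such as predictions alongside the gradient.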

Return type:

Callable

For example:

>>> import jax
>>>
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
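Because fun must return a scalar, calling grad_tanh on a vector raises a TypeError. A minimal sketch of two common follow-ups: mapping the scalar gradient over each element with jax.vmap, and nesting jax.grad to get a second derivative. Both are checked against the closed forms d/dx tanh(x) = 1 - tanh(x)² and d²/dx² tanh(x) = -2 tanh(x)(1 - tanh(x)²).

```python
import jax
import jax.numpy as jnp

grad_tanh = jax.grad(jnp.tanh)    # first derivative of a scalar function
grad2_tanh = jax.grad(grad_tanh)  # grad composes: second derivative

xs = jnp.array([0.0, 0.2, 1.0])

# grad_tanh(xs) would raise a TypeError because tanh(xs) is not a scalar.
# vmap applies the scalar-input gradient to each element of xs instead.
first = jax.vmap(grad_tanh)(xs)
second = jax.vmap(grad2_tanh)(xs)

t = jnp.tanh(xs)
print(jnp.allclose(first, 1.0 - t ** 2))
print(jnp.allclose(second, -2.0 * t * (1.0 - t ** 2)))
```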