jax.value_and_grad

jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())

Create a function that evaluates both fun and the gradient of fun.
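The short sketch below shows the basic call pattern, assuming JAX is installed. The function f and the input x are illustrative examples, not part of the API; the point is that one call returns the pair (value, gradient) that would otherwise take two separate calls.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar-valued function: sum of squares, whose gradient is 2 * x.
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

# A single call returns the pair (f(x), gradient of f at x).
value, gradient = jax.value_and_grad(f)(x)

# The same results from two separate calls; this evaluates f twice
# (once directly and once inside jax.grad).
value_separate = f(x)
gradient_separate = jax.grad(f)(x)

print(value)     # 14.0
print(gradient)  # [2. 4. 6.]
```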

Parameters:
  • fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,), etc.).

  • argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.

  • allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer-valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.

  • reduce_axes (Sequence[AxisName])
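The sketch below exercises the scalar-output requirement on fun and the holomorphic and allow_int flags listed above, assuming JAX is installed with its default 32-bit settings. The helpers square, scale, and not_scalar are hypothetical examples. Because the exact error text may differ between JAX versions, the code only records that a TypeError is raised for a non-scalar output.

```python
import jax
import jax.numpy as jnp

# holomorphic=True: for a complex-differentiable fun, the gradient is the
# complex derivative f'(z). For square(z) = z**2, f'(z) = 2z.
def square(z):
    return z ** 2

z = jnp.array(3.0 + 4.0j)  # complex input, as holomorphic=True requires
val_c, grad_c = jax.value_and_grad(square, holomorphic=True)(z)
# val_c is (3+4j)**2 = -7+24j; grad_c is 2z = 6+8j

# allow_int=True: an integer argument may be selected by argnums; its
# gradient is a placeholder with the trivial dtype float0.
def scale(x, n):
    return x * n

val_i, (gx, gn) = jax.value_and_grad(scale, argnums=(0, 1), allow_int=True)(2.0, 3)
# val_i is 6.0, gx is d(x*n)/dx = n = 3.0, and gn.dtype is jax.dtypes.float0

# fun must return a scalar: shape () is accepted, shape (1,) is rejected.
def not_scalar(x):
    return jnp.reshape(jnp.sum(x), (1,))

error_message = None
try:
    jax.value_and_grad(not_scalar)(jnp.ones(3))
except TypeError as err:
    error_message = str(err)  # exact wording may vary across JAX versions
```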

Returns:

A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer, then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True, then a tuple of ((value, auxiliary_data), gradient) is returned.
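To make the return structure concrete, here is a minimal sketch combining a sequence argnums with has_aux=True, assuming JAX is installed. The loss function and its inputs are hypothetical; the auxiliary output is simply the intermediate prediction. It also shows that an integer argnums yields a single gradient rather than a tuple.

```python
import jax
import jax.numpy as jnp

def loss(w, b, x):
    # Linear prediction followed by a sum-of-squares loss (illustrative).
    pred = w * x + b
    # With has_aux=True, fun returns (scalar_to_differentiate, aux_data).
    return jnp.sum(pred ** 2), pred

w, b = 2.0, 1.0
x = jnp.array([1.0, 2.0, 3.0])

# argnums=(0, 1): differentiate with respect to w and b, not x.
# Result structure: ((value, auxiliary_data), (grad_w, grad_b)).
(value, pred), (grad_w, grad_b) = jax.value_and_grad(
    loss, argnums=(0, 1), has_aux=True)(w, b, x)
# pred = [3, 5, 7]; value = 9 + 25 + 49 = 83
# grad_w = sum(2 * pred * x) = 68; grad_b = sum(2 * pred) = 30

# With an integer argnums, the gradient is a single value, not a tuple.
(_, _), grad_w_only = jax.value_and_grad(loss, argnums=0, has_aux=True)(w, b, x)
```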

Return type:

Callable[..., tuple[Any, Any]]
