
jax.numpy.gradient

jax.numpy.gradient(f, *varargs, axis=None, edge_order=None)

Compute the numerical gradient of a sampled function.

JAX implementation of numpy.gradient().

The gradient in jnp.gradient is computed using second-order finite differences across the array of sampled function values. This should not be confused with jax.grad(), which computes a precise gradient of a callable function via automatic differentiation.

Parameters:
  • f (ArrayLike) – N-dimensional array of function values.

  • varargs (ArrayLike) –

    optional list of scalars or arrays specifying the spacing of function evaluations. Options are:

    • not specified: unit spacing in all dimensions.

    • a single scalar: constant spacing in all dimensions.

    • N values: specify different spacing in each dimension:

      • scalar values indicate constant spacing in that dimension.

      • array values must match the length of the corresponding dimension, and specify the coordinates at which f is evaluated.

  • edge_order (int | None) – not implemented in JAX

  • axis (int | Sequence[int] | None) – integer or tuple of integers specifying the axis along which to compute the gradient. If None (default), the gradient is computed along all axes.

Returns:

an array or list of arrays containing the numerical gradient along each specified axis.

Return type:

Array | list[Array]

See also

  • jax.grad(): automatic differentiation of a function with a single output.

Examples

Comparing numerical and automatic differentiation of a simple function:

>>> def f(x):
...   return jnp.sin(x) * jnp.exp(-x / 4)
...
>>> def gradf_exact(x):
...   # exact analytical gradient of f(x)
...   return -f(x) / 4 + jnp.cos(x) * jnp.exp(-x / 4)
...
>>> x = jnp.linspace(0, 5, 10)

>>> with jnp.printoptions(precision=2, suppress=True):
...   print("numerical gradient:", jnp.gradient(f(x), x))
...   print("automatic gradient:", jax.vmap(jax.grad(f))(x))
...   print("exact gradient:    ", gradf_exact(x))
...
numerical gradient: [ 0.83  0.61  0.18 -0.2  -0.43 -0.49 -0.39 -0.21 -0.02  0.08]
automatic gradient: [ 1.    0.62  0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01  0.15]
exact gradient:     [ 1.    0.62  0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01  0.15]

Notice that, as expected, the numerical gradient has some approximation error compared to the automatic gradient computed via jax.grad().
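For multi-dimensional input, the axis parameter controls whether a single array or a list of arrays is returned. A minimal sketch, assuming the Array | list[Array] return type documented above (variable names are illustrative):

```python
import jax.numpy as jnp

# A 2-D field sampled on a regular grid.
z = jnp.arange(12.0).reshape(3, 4)

# axis=None (default): one gradient array per axis, returned as a list.
grads = jnp.gradient(z)

# A single axis: a single array of the same shape as z.
g0 = jnp.gradient(z, axis=0)
```

Since z here is linear in both directions (rows step by 4, columns by 1), the finite differences are constant: `grads[0]` is 4 everywhere and `grads[1]` is 1 everywhere.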
