jax.numpy.gradient
- jax.numpy.gradient(f, *varargs, axis=None, edge_order=None)
Compute the numerical gradient of a sampled function.
JAX implementation of numpy.gradient().
The gradient in jnp.gradient is computed using second-order finite differences across the array of sampled function values. This should not be confused with jax.grad(), which computes a precise gradient of a callable function via automatic differentiation.
- Parameters:
f (ArrayLike) – N-dimensional array of function values.
varargs (ArrayLike) – optional list of scalars or arrays specifying the spacing of function evaluations. Options are:
  - not specified: unit spacing in all dimensions.
  - a single scalar: constant spacing in all dimensions.
  - N values: specify different spacing in each dimension:
    - scalar values indicate constant spacing in that dimension.
    - array values must match the length of the corresponding dimension, and specify the coordinates at which f is evaluated.
edge_order (int | None) – not implemented in JAX.
axis (int | Sequence[int] | None) – integer or tuple of integers specifying the axis or axes along which to compute the gradient. If None (default), the gradient is calculated along all axes.
- Returns:
an array or tuple of arrays containing the numerical gradient along each specified axis.
- Return type:
See also
jax.grad(): automatic differentiation of a function with a single output.
Examples
Comparing numerical and automatic differentiation of a simple function:
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x):
...   return jnp.sin(x) * jnp.exp(-x / 4)
...
>>> def gradf_exact(x):
...   # exact analytical gradient of f(x)
...   return -f(x) / 4 + jnp.cos(x) * jnp.exp(-x / 4)
...
>>> x = jnp.linspace(0, 5, 10)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print("numerical gradient:", jnp.gradient(f(x), x))
...   print("automatic gradient:", jax.vmap(jax.grad(f))(x))
...   print("exact gradient:    ", gradf_exact(x))
...
numerical gradient: [ 0.83  0.61  0.18 -0.2  -0.43 -0.49 -0.39 -0.21 -0.02  0.08]
automatic gradient: [ 1.    0.62  0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01  0.15]
exact gradient:     [ 1.    0.62  0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01  0.15]
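Notice that, as expected, the numerical gradient has some approximation error compared to the automatic gradient computed via jax.grad(). In this output the error is largest at the two endpoints, while the automatic gradient matches the exact gradient to the printed precision.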
