jax.numpy.diff
Contents
jax.numpy.diff#
- jax.numpy.diff(a,n=1,axis=-1,prepend=None,append=None)[source]#
Calculate n-th order difference between array elements along a given axis.
JAX implementation of
numpy.diff().The first order difference is computed by
a[i+1]-a[i], and the n-th orderdifference is computedntimes recursively.- Parameters:
a (ArrayLike) – input array. Must have
a.ndim>=1.n (int) – int, optional, default=1. Order of the difference. Specifies the numberof times the difference is computed. If n=0, no difference is computed andinput is returned as is.
axis (int) – int, optional, default=-1. Specifies the axis along which the differenceis computed. The difference is computed along
axis-1by default.prepend (ArrayLike |None) – scalar or array, optional, default=None. Specifies the values to beprepended along
axisbefore computing the difference.append (ArrayLike |None) – scalar or array, optional, default=None. Specifies the values to beappended along
axisbefore computing the difference.
- Returns:
An array containing the n-th order difference between the elements of
a.- Return type:
See also
jax.numpy.ediff1d(): Computes the differences between consecutiveelements of an array.jax.numpy.cumsum(): Computes the cumulative sum of the elements ofthe array along a given axis.jax.numpy.gradient(): Computes the gradient of an N-dimensional array.
Examples
jnp.diffcomputes the first order difference alongaxis, by default.>>>a=jnp.array([[1,5,2,9],...[3,8,7,4]])>>>jnp.diff(a)Array([[ 4, -3, 7], [ 5, -1, -3]], dtype=int32)
When
n=2, second order difference is computed alongaxis.>>>jnp.diff(a,n=2)Array([[-7, 10], [-6, -2]], dtype=int32)
When
prepend=2, it is prepended toaalongaxisbefore computingthe difference.>>>jnp.diff(a,prepend=2)Array([[-1, 4, -3, 7], [ 1, 5, -1, -3]], dtype=int32)
When
append=jnp.array([[3],[1]]), it is appended toaalongaxisbefore computing the difference.>>>jnp.diff(a,append=jnp.array([[3],[1]]))Array([[ 4, -3, 7, -6], [ 5, -1, -3, -3]], dtype=int32)
