jax.numpy.ediff1d
- jax.numpy.ediff1d(ary, to_end=None, to_begin=None)
Compute the differences of the elements of the flattened array.
JAX implementation of numpy.ediff1d().
- Parameters:
ary (ArrayLike) – input array or scalar.
to_end (ArrayLike | None) – scalar or array, optional, default=None. Specifies the numbers to append to the resulting array.
to_begin (ArrayLike | None) – scalar or array, optional, default=None. Specifies the numbers to prepend to the resulting array.
- Returns:
An array containing the differences between the elements of the input array.
- Return type:
Array
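To make the computation concrete, here is a minimal sketch of equivalent logic written with public jax.numpy functions. The helper name `ediff1d_sketch` is hypothetical and this is not JAX's actual source; it assumes that `to_begin` and `to_end` are cast to the dtype of `ary` (as the Note on this page describes) and flattened before being concatenated.

```python
import jax.numpy as jnp


def ediff1d_sketch(ary, to_end=None, to_begin=None):
    """Hypothetical illustration of ediff1d's behavior; not JAX's source."""
    arr = jnp.ravel(jnp.asarray(ary))  # flatten in row-major (C) order
    result = arr[1:] - arr[:-1]        # differences of consecutive elements
    if to_begin is not None:
        # Assumption: cast to the input dtype without any precision check.
        begin = jnp.ravel(jnp.asarray(to_begin, dtype=arr.dtype))
        result = jnp.concatenate([begin, result])
    if to_end is not None:
        end = jnp.ravel(jnp.asarray(to_end, dtype=arr.dtype))
        result = jnp.concatenate([result, end])
    return result


a = jnp.array([2, 3, 5, 9, 1, 4])
print(ediff1d_sketch(a, to_begin=-10, to_end=jnp.array([20, 30])))
```

Flattening first is what lets scalars and multi-dimensional inputs share one code path; a scalar input yields an empty array of differences.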
Note
Unlike NumPy’s implementation of ediff1d, jax.numpy.ediff1d() will not issue an error if casting to_end or to_begin to the dtype of ary loses precision.
See also
- jax.numpy.diff(): Computes the n-th order difference between elements of the array along a given axis.
- jax.numpy.cumsum(): Computes the cumulative sum of the elements of the array along a given axis.
- jax.numpy.gradient(): Computes the gradient of an N-dimensional array.
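The relationships with diff and cumsum can be checked directly. This sketch relies only on the documented behavior: because ediff1d differences the flattened array, it should agree with a first-order jnp.diff of the raveled input, and prepending the first element via `to_begin` lets jnp.cumsum rebuild the flattened array.

```python
import jax.numpy as jnp

a = jnp.array([[2, -1, 4, 7],
               [3, 5, -6, 9]])
flat = jnp.ravel(a)  # row-major flattening: [2, -1, 4, 7, 3, 5, -6, 9]

# First-order differences of the flattened array, computed two ways.
d = jnp.ediff1d(a)
matches_diff = bool(jnp.array_equal(d, jnp.diff(flat)))

# Prepending the first element and taking a running sum undoes the differencing.
rebuilt = jnp.cumsum(jnp.ediff1d(a, to_begin=flat[0]))
roundtrip_ok = bool(jnp.array_equal(rebuilt, flat))

print(matches_diff, roundtrip_ok)  # True True
```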
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([2, 3, 5, 9, 1, 4])
>>> jnp.ediff1d(a)
Array([ 1,  2,  4, -8,  3], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10)
Array([-10,   1,   2,   4,  -8,   3], dtype=int32)
>>> jnp.ediff1d(a, to_end=jnp.array([20, 30]))
Array([ 1,  2,  4, -8,  3, 20, 30], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30]))
Array([-10,   1,   2,   4,  -8,   3,  20,  30], dtype=int32)
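As the Note above states, `to_begin` and `to_end` are cast to the dtype of `ary` without a precision check. A small sketch, assuming JAX's default 32-bit mode (so `a` is int32): a fractional `to_end` value is converted to an integer, whereas NumPy's ediff1d would raise an error for the same call.

```python
import jax.numpy as jnp

a = jnp.array([2, 3, 5, 9, 1, 4])  # int32 in JAX's default 32-bit mode

# 2.5 is cast to a.dtype; JAX does not raise even though precision is lost.
out = jnp.ediff1d(a, to_end=2.5)

print(out.dtype)  # same as a.dtype
print(out)        # the appended value is the integer 2, not 2.5
```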
For arrays with ndim > 1, the differences are computed after flattening the input array in row-major order, so the two calls below give the same result.
>>> a1 = jnp.array([[2, -1, 4, 7],
...                 [3, 5, -6, 9]])
>>> jnp.ediff1d(a1)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
>>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9])
>>> jnp.ediff1d(a2)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
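Because flattening joins the rows end to end, the flattened result includes a difference across the row boundary (7 followed by 3 gives -4). If per-row differences are wanted instead, jnp.diff with `axis=1` or mapping ediff1d over rows with jax.vmap are two options. This sketch compares them; it is an illustration, not a recommendation taken from the JAX docs.

```python
import jax
import jax.numpy as jnp

a1 = jnp.array([[2, -1, 4, 7],
                [3, 5, -6, 9]])

# Flattened: includes the cross-row step 7 -> 3.
flat_diffs = jnp.ediff1d(a1)

# Per-row alternatives that never cross a row boundary (both have shape (2, 3)).
row_diffs = jnp.diff(a1, axis=1)
row_diffs_vmap = jax.vmap(jnp.ediff1d)(a1)  # applies ediff1d to each row separately

print(flat_diffs)
print(row_diffs)
print(row_diffs_vmap)
```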
