jax.numpy.flatnonzero
- jax.numpy.flatnonzero(a, *, size=None, fill_value=None)
Return indices of nonzero elements in a flattened array.
JAX implementation of numpy.flatnonzero(). jnp.flatnonzero(a) is equivalent to nonzero(ravel(a))[0]. For a full discussion of the parameters to this function, refer to jax.numpy.nonzero().
- Parameters:
a (ArrayLike) – N-dimensional array.
size (int | None) – optional static integer specifying the number of nonzero entries to return. See jax.numpy.nonzero() for more discussion of this parameter.
fill_value (None | ArrayLike | tuple[ArrayLike, ...]) – optional padding value when size is specified. Defaults to 0. See jax.numpy.nonzero() for more discussion of this parameter.
- Returns:
Array containing the indices of each nonzero value in the flattened array.
- Return type:
Array
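The size and fill_value parameters matter mostly under jax.jit, where output shapes must be known at trace time, so the number of returned indices has to be fixed in advance. The sketch below is a minimal illustration under that assumption; first_n_nonzero is a hypothetical helper name, and fill_value=-1 is an illustrative choice rather than a library default.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: `size` is marked static so the output shape is
# known when the function is traced by jax.jit.
@partial(jax.jit, static_argnames=["size"])
def first_n_nonzero(a, size):
    # Pad with -1 instead of the default 0, so padding cannot be
    # confused with a genuine index 0.
    return jnp.flatnonzero(a, size=size, fill_value=-1)


x = jnp.array([[0, 5, 0],
               [6, 0, 8]])

padded = first_n_nonzero(x, size=5)     # 3 nonzeros, padded to length 5
truncated = first_n_nonzero(x, size=2)  # only the first 2 indices are kept

print(padded.tolist())     # [1, 3, 5, -1, -1]
print(truncated.tolist())  # [1, 3]
```

When size is larger than the number of nonzero entries, the trailing slots hold fill_value; when it is smaller, the result is truncated to the first size indices.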
Examples
>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 8]])
>>> jnp.flatnonzero(x)
Array([1, 3, 5], dtype=int32)
This is equivalent to calling nonzero() on the flattened array and extracting the first entry in the resulting tuple:
>>> jnp.nonzero(x.ravel())[0]
Array([1, 3, 5], dtype=int32)
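Because jnp.flatnonzero is described as the JAX implementation of numpy.flatnonzero(), one quick sanity check is to compare the two on the same input. This sketch assumes NumPy is available (it is installed as a JAX dependency) and compares values only, since the index dtypes can differ: NumPy returns intp indices (int64 on most 64-bit platforms), while JAX returns int32 unless 64-bit mode is enabled.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 8]])

jax_idx = jnp.flatnonzero(x)
np_idx = np.flatnonzero(np.asarray(x))

# Compare values only, not dtypes (int32 in JAX vs. intp in NumPy).
assert np.array_equal(np.asarray(jax_idx), np_idx)
print(np_idx.tolist())  # [1, 3, 5]
```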
The returned indices can be used to extract nonzero entries from the flattened array:
>>> indices = jnp.flatnonzero(x)
>>> x.ravel()[indices]
Array([5, 6, 8], dtype=int32)
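Flat indices can also be mapped back to N-dimensional coordinates with jnp.unravel_index. For the same input, this recovers the tuple that jnp.nonzero(x) returns on the unflattened array. A minimal sketch, reusing the 2-D example array:

```python
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 8]])

flat = jnp.flatnonzero(x)                      # indices into x.ravel()
rows, cols = jnp.unravel_index(flat, x.shape)  # back to (row, col) pairs

# The recovered coordinates match jnp.nonzero on the original 2-D array.
nz_rows, nz_cols = jnp.nonzero(x)
print(rows.tolist(), cols.tolist())  # [0, 1, 1] [1, 0, 2]
```

Going the other way, jnp.ravel_multi_index converts coordinate arrays back to flat indices, which can be convenient when the rest of a computation works on x.ravel().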
