jax.numpy.argwhere
- jax.numpy.argwhere(a, *, size=None, fill_value=None)
Find the indices of nonzero array elements.
JAX implementation of numpy.argwhere().
jnp.argwhere(x) is essentially equivalent to jnp.column_stack(jnp.nonzero(x)), with special handling for zero-dimensional (i.e. scalar) inputs.
Because the size of the output of argwhere is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which specifies the size of the leading dimension of the output. It must be specified statically for jnp.argwhere to be compiled with non-static operands. See jax.numpy.nonzero() for a full discussion of size and its semantics.
- Parameters:
a (ArrayLike) – array for which to find nonzero elements.
size (int | None) – optional integer specifying statically the number of expected nonzero elements. This must be specified in order to use argwhere within JAX transformations like jax.jit(). If a has more than size nonzero elements the result is truncated; if it has fewer, the result is padded with fill_value. See jax.numpy.nonzero() for more information.
fill_value (ArrayLike | None) – optional array specifying the value used to pad the output when size is specified (defaults to 0). See jax.numpy.nonzero() for more information.
- Returns:
a two-dimensional array of shape (size, a.ndim). If size is not specified as an argument, it is equal to the number of nonzero elements in a.
- Return type:
Array
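The size and fill_value parameters are what make argwhere usable under jax.jit(). The sketch below wraps it in a jitted function with size marked static, so the output shape is fixed at trace time; the helper name nonzero_coords and the choice of -1 as a fill value are illustrative assumptions, not part of the API.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: `size` is static, so the output shape (size, x.ndim)
# is known when the function is traced.
@partial(jax.jit, static_argnames="size")
def nonzero_coords(x, size):
    # Rows beyond the true number of nonzero elements are filled with -1;
    # nonzero elements beyond `size` are dropped.
    return jnp.argwhere(x, size=size, fill_value=-1)


x = jnp.array([[1, 0, 2],
               [0, 3, 0]])

padded = nonzero_coords(x, size=5)     # 3 nonzero elements, padded to 5 rows
truncated = nonzero_coords(x, size=2)  # only the first 2 coordinates are kept
print(padded)
print(truncated)
```

Calling jnp.argwhere(x) inside a jitted function without size fails at trace time, because the number of rows would depend on the values of x rather than only on its shape.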
Examples
Two-dimensional array:
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 0, 2],
...                [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)
Equivalent computation using jax.numpy.column_stack() and jax.numpy.nonzero():
>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)
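Since each row of the argwhere result is one coordinate, transposing it gives one index array per axis, which can be used to gather the nonzero values. This is a minimal sketch relying on ordinary NumPy-style advanced indexing; it is not a feature of argwhere itself.

```python
import jax.numpy as jnp

x = jnp.array([[1, 0, 2],
               [0, 3, 0]])

coords = jnp.argwhere(x)     # shape (num_nonzero, x.ndim): one row per nonzero element
values = x[tuple(coords.T)]  # coords.T has one row per axis -> tuple of index arrays
same = x[jnp.nonzero(x)]     # jnp.nonzero already returns that tuple of index arrays
print(values, same)
```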
Special case for zero-dimensional (i.e. scalar) inputs:
>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)
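When size is not given, the output always has one row per nonzero element and one column per input dimension, which is why a nonzero scalar yields shape (1, 0) and a zero scalar yields (0, 0). The sketch below checks that invariant for inputs of several ranks; check_argwhere_shape is a hypothetical helper written only for this illustration.

```python
import jax.numpy as jnp


def check_argwhere_shape(a):
    """Hypothetical helper: assert argwhere returns shape (num_nonzero, ndim)."""
    a = jnp.asarray(a)
    out = jnp.argwhere(a)
    expected = (int(jnp.count_nonzero(a)), a.ndim)
    assert out.shape == expected, (out.shape, expected)
    return out.shape


inputs = [
    0,                                  # scalar zero      -> (0, 0)
    1,                                  # scalar nonzero   -> (1, 0)
    jnp.array([0, 5, 0, 7]),            # 1-D              -> (2, 1)
    jnp.array([[1, 0, 2], [0, 3, 0]]),  # 2-D              -> (3, 2)
    jnp.zeros((2, 2, 2)),               # 3-D, all zeros   -> (0, 3)
]
shapes = [check_argwhere_shape(a) for a in inputs]
print(shapes)
```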
