jax.numpy.nonzero
jax.numpy.nonzero(a, *, size=None, fill_value=None)
Return indices of nonzero elements of an array.
JAX implementation of numpy.nonzero().
Because the size of the output of nonzero is data-dependent, the function is not compatible with JIT and other transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.nonzero to be used within JAX’s transformations.
- Parameters:
a (ArrayLike) – N-dimensional array.
size (int | None) – optional static integer specifying the number of nonzero entries to return. If there are more nonzero elements than the specified size, then indices will be truncated at the end. If there are fewer nonzero elements than the specified size, then indices will be padded with fill_value, which defaults to zero.
fill_value (None | ArrayLike | tuple[ArrayLike, ...]) – optional padding value when size is specified. Defaults to 0.
- Returns:
Tuple of JAX Arrays of length a.ndim, containing the indices of each nonzero value.
- Return type:
tuple[Array, ...]
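Because fill_value may also be a tuple, each returned index array can be padded with its own value. The sketch below assumes one scalar per dimension (a tuple of length a.ndim), as the tuple[ArrayLike, ...] annotation suggests, and runs eagerly outside of jit; the fill values -1 and -2 are arbitrary choices for illustration.

```python
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 7]])

# Request 4 entries although only 3 are nonzero, so one slot is padded.
# Assumption: one fill value per dimension (rows pad with -1, columns with -2).
rows, cols = jnp.nonzero(x, size=4, fill_value=(-1, -2))
print(rows.tolist())  # [0, 1, 1, -1]
print(cols.tolist())  # [1, 0, 2, -2]
```

A per-dimension sentinel like this makes padded entries easy to spot, whereas the default of 0 is indistinguishable from a genuine index 0.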
Examples
One-dimensional array returns a length-1 tuple of indices:
>>> import jax.numpy as jnp
>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jnp.nonzero(x)
(Array([1, 3, 5], dtype=int32),)
Two-dimensional array returns a length-2 tuple of indices:
>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 7]])
>>> jnp.nonzero(x)
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
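The tuple holds one index array per dimension. If per-element coordinates are more convenient, the arrays can be stacked into an (N, ndim) array; JAX's jnp.argwhere returns that layout directly. A minimal sketch using the same two-dimensional input:

```python
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 7]])

rows, cols = jnp.nonzero(x)                 # one index array per dimension
coords = jnp.stack((rows, cols), axis=-1)   # shape (num_nonzero, x.ndim)
print(coords.tolist())  # [[0, 1], [1, 0], [1, 2]]

# jnp.argwhere yields the same (N, ndim) coordinate layout.
print(bool(jnp.array_equal(coords, jnp.argwhere(x))))  # True
```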
In either case, the resulting tuple of indices can be used directly to extract the nonzero values:
>>> indices = jnp.nonzero(x)
>>> x[indices]
Array([5, 6, 7], dtype=int32)
The output of nonzero has a dynamic shape, because the number of returned indices depends on the contents of the input array. As such, it is incompatible with JIT and other JAX transformations:
>>> import jax
>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jax.jit(jnp.nonzero)(x)
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
This can be addressed by passing a static size parameter to specify the desired output shape:
>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size')
>>> nonzero_jit(x, size=3)
(Array([1, 3, 5], dtype=int32),)
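When the number of nonzero entries is not known in advance, one option (a sketch, not a pattern this page prescribes) is to use the total element count as the static size. Inside a jitted function the input's shape is static, so x.size is an ordinary Python int and no index can ever be truncated. The function name all_nonzero_indices is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def all_nonzero_indices(x):
    # x.size is known at trace time, so it is a valid static size.
    # Flattening gives a single index array; unused slots are padded with 0.
    (idx,) = jnp.nonzero(x.ravel(), size=x.size)
    return idx

x = jnp.array([0, 5, 0, 6, 0, 7])
print(all_nonzero_indices(x).tolist())  # [1, 3, 5, 0, 0, 0]
```

The trade-off is a larger output buffer: every call returns x.size indices, most of which may be padding.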
If size does not match the actual number of nonzero elements, the result will be either truncated or padded:
>>> nonzero_jit(x, size=2)  # size < 3: indices are truncated
(Array([1, 3], dtype=int32),)
>>> nonzero_jit(x, size=5)  # size > 3: indices are padded with zeros.
(Array([1, 3, 5, 0, 0], dtype=int32),)
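Padding with zeros means the padded slots gather x[0] when the indices are used for lookup, which silently corrupts reductions whenever x[0] is nonzero. A minimal sketch of one way to mask them out, assuming jnp.count_nonzero to count the real entries; sum_of_nonzero is a hypothetical helper name and the input is chosen so that x[0] is nonzero.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames='size')
def sum_of_nonzero(x, size):
    (idx,) = jnp.nonzero(x, size=size)   # padded slots hold index 0
    vals = x[idx]                        # so padded slots read x[0]
    # Only the first count_nonzero(x) slots are real entries.
    valid = jnp.arange(size) < jnp.count_nonzero(x)
    return jnp.where(valid, vals, 0).sum()

x = jnp.array([2, 0, 3, 0, 0, 4])
print(x[jnp.nonzero(x, size=5)].sum())  # 13: both padded slots read x[0] == 2
print(sum_of_nonzero(x, size=5))         # 9
```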
You can specify a custom fill value for the padding using the fill_value argument:
>>> nonzero_jit(x, size=5, fill_value=len(x))
(Array([1, 3, 5, 6, 6], dtype=int32),)
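A fill value of len(x) points one past the end of the array, so padded slots are out of bounds rather than aliasing a real element. The sketch below pairs that with the .at[...].get(mode='fill', fill_value=0) gather, which returns the given value for out-of-bounds indices; relying on explicit mode='fill' rather than the default indexing behavior is a deliberate choice here. The helper name gather_nonzero is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames='size')
def gather_nonzero(x, size):
    # Pad with an out-of-bounds index (len(x)) instead of 0.
    (idx,) = jnp.nonzero(x, size=size, fill_value=x.shape[0])
    # mode='fill' returns fill_value for any out-of-bounds index.
    return x.at[idx].get(mode='fill', fill_value=0)

x = jnp.array([2, 0, 3, 0, 0, 4])
print(gather_nonzero(x, size=5).tolist())  # [2, 3, 4, 0, 0]
```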
