jax.numpy.nonzero

jax.numpy.nonzero(a, *, size=None, fill_value=None)

Return indices of nonzero elements of an array.

JAX implementation of numpy.nonzero().

Because the size of the output of nonzero is data-dependent, the function is not compatible with JIT and other transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.nonzero to be used within JAX’s transformations.

Parameters:
  • a (ArrayLike) – N-dimensional array.

  • size (int | None) – optional static integer specifying the number of nonzero entries to return. If there are more nonzero elements than the specified size, then indices will be truncated at the end. If there are fewer nonzero elements than the specified size, then indices will be padded with fill_value, which defaults to zero.

  • fill_value (None | ArrayLike | tuple[ArrayLike, ...]) – optional padding value when size is specified. Defaults to 0.

Returns:

Tuple of JAX Arrays of length a.ndim, containing the indices of each nonzero value.

Return type:

tuple[Array, …]
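The tuple form of fill_value listed under Parameters can be illustrated with a two-dimensional input. The sketch below assumes that a tuple supplies one scalar padding value per dimension. The input array and the choice of (2, 3) as padding (the length of each axis, so the padded positions are out of bounds) are made up for illustration.

```python
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 7]])

# One padding value per dimension: row indices are padded with 2,
# column indices with 3 (both chosen here as out-of-bounds sentinels).
rows, cols = jnp.nonzero(x, size=5, fill_value=(2, 3))
print(rows)  # [0 1 1 2 2]
print(cols)  # [1 0 2 3 3]
```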

Examples

One-dimensional array returns a length-1 tuple of indices:

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jnp.nonzero(x)
(Array([1, 3, 5], dtype=int32),)

Two-dimensional array returns a length-2 tuple of indices:

>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 7]])
>>> jnp.nonzero(x)
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))

In either case, the resulting tuple of indices can be used directly to extract the nonzero values:

>>> indices = jnp.nonzero(x)
>>> x[indices]
Array([5, 6, 7], dtype=int32)

The output of nonzero has a dynamic shape, because the number of returned indices depends on the contents of the input array. As such, it is incompatible with JIT and other JAX transformations:

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jax.jit(jnp.nonzero)(x)
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.

This can be addressed by passing a static size parameter to specify the desired output shape:

>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size')
>>> nonzero_jit(x, size=3)
(Array([1, 3, 5], dtype=int32),)
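As an alternative to static_argnames, size can be fixed before the function is transformed, for example with functools.partial or a lambda. The following is a minimal sketch of that idea. The names nonzero3, x2d, and row_idx are illustrative only, and the vmap usage assumes that a fixed size is enough to give every row the same output shape.

```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.array([0, 5, 0, 6, 0, 7])

# size is bound ahead of time, so jit never sees it as a traced argument.
nonzero3 = jax.jit(partial(jnp.nonzero, size=3))
(idx,) = nonzero3(x)
print(idx)  # [1 3 5]

# The same idea under vmap: each row yields exactly two indices,
# padded with the default fill value when a row has fewer nonzeros.
x2d = jnp.array([[0, 5, 0],
                 [6, 0, 7]])
row_idx = jax.vmap(lambda row: jnp.nonzero(row, size=2)[0])(x2d)
print(row_idx.shape)  # (2, 2)
```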

If size does not match the true size, the result will be either truncated or padded:

>>> nonzero_jit(x, size=2)  # size < 3: indices are truncated
(Array([1, 3], dtype=int32),)
>>> nonzero_jit(x, size=5)  # size > 3: indices are padded with zeros.
(Array([1, 3, 5, 0, 0], dtype=int32),)
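Because a padded 0 is indistinguishable from a genuine index 0, it can help to return a validity mask alongside the indices. The sketch below is one possible way to do this, not part of jnp.nonzero itself. The helper name nonzero_with_mask is hypothetical, and it compares positions against jnp.count_nonzero to mark which entries are real.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="size")
def nonzero_with_mask(x, size):
    """Hypothetical helper: padded indices plus a mask of real entries."""
    (idx,) = jnp.nonzero(x, size=size)
    # Positions beyond the true nonzero count are padding.
    valid = jnp.arange(size) < jnp.count_nonzero(x)
    return idx, valid


x = jnp.array([0, 5, 0, 6, 0, 7])
idx, valid = nonzero_with_mask(x, size=5)
print(idx)    # [1 3 5 0 0]
print(valid)  # [ True  True  True False False]
```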

You can specify a custom fill value for the padding using the fill_value argument:

>>> nonzero_jit(x, size=5, fill_value=len(x))
(Array([1, 3, 5, 6, 6], dtype=int32),)
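An out-of-bounds fill value such as len(x) can act as a sentinel when the indices are later used for a gather. The sketch below assumes the indexing mode "fill" of x.at[...].get, which substitutes a chosen value for out-of-bounds indices. The function name first_nonzero_values and the placeholder value -1 are illustrative choices.

```python
import jax
import jax.numpy as jnp


@jax.jit
def first_nonzero_values(x):
    # size is a Python literal and x.shape[0] is static, so this traces under jit.
    (idx,) = jnp.nonzero(x, size=5, fill_value=x.shape[0])
    # Padded entries point one past the end; mode="fill" maps them to -1.
    return x.at[idx].get(mode="fill", fill_value=-1)


x = jnp.array([0, 5, 0, 6, 0, 7])
values = first_nonzero_values(x)
print(values)  # [ 5  6  7 -1 -1]
```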