
jax.numpy.argwhere


jax.numpy.argwhere(a, *, size=None, fill_value=None)

Find the indices of nonzero array elements.

JAX implementation of numpy.argwhere().

jnp.argwhere(x) is essentially equivalent to jnp.column_stack(jnp.nonzero(x)), with special handling for zero-dimensional (i.e. scalar) inputs.
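The equivalence with jnp.column_stack(jnp.nonzero(x)) can be checked on a higher-dimensional input. The following is a minimal sketch; the 3-D test array and its nonzero positions are arbitrary choices for illustration.

```python
import jax.numpy as jnp

# Arbitrary 3-D input with three nonzero entries.
x = jnp.zeros((2, 3, 4), dtype=jnp.int32)
x = x.at[0, 1, 2].set(5).at[1, 0, 0].set(7).at[1, 2, 3].set(9)

via_argwhere = jnp.argwhere(x)
# nonzero returns one index array per axis; column_stack turns them into
# one row per nonzero element and one column per axis.
via_nonzero = jnp.column_stack(jnp.nonzero(x))

print(via_argwhere.shape)                          # (3, 3)
print(via_argwhere.tolist())                       # [[0, 1, 2], [1, 0, 0], [1, 2, 3]]
print(bool(jnp.array_equal(via_argwhere, via_nonzero)))  # True
```

Rows appear in row-major (C) order of the input, which is the order in which jnp.nonzero reports the indices.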

Because the size of the output of argwhere is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which specifies the size of the leading dimension of the output; it must be specified statically for jnp.argwhere to be compiled with non-static operands. See jax.numpy.nonzero() for a full discussion of size and its semantics.
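A minimal sketch of using argwhere under jax.jit, assuming size is passed as a static argument via static_argnames. The wrapper name nonzero_coords is hypothetical and chosen only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `size` fixes the leading dimension of the output, so it must be a plain
# Python int at trace time; static_argnames tells jit not to trace it.
# Calling jnp.argwhere(x) without size inside jit would fail, because the
# output shape would then depend on the values in x.
@partial(jax.jit, static_argnames="size")
def nonzero_coords(x, size):  # hypothetical helper name
    return jnp.argwhere(x, size=size)


x = jnp.array([[1, 0, 2],
               [0, 3, 0]])
coords = nonzero_coords(x, size=3)
print(coords.shape)     # (3, 2)
print(coords.tolist())  # [[0, 0], [0, 2], [1, 1]]
```

Each distinct value of size triggers a separate compilation, so in practice it is usually fixed to an upper bound on the expected number of nonzero elements.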

Parameters:
  • a (ArrayLike) – array for which to find nonzero elements

  • size (int | None) – optional integer statically specifying the number of expected nonzero elements. This must be specified in order to use argwhere within JAX transformations like jax.jit(). See jax.numpy.nonzero() for more information.

  • fill_value (ArrayLike | None) – optional array specifying the fill value used for padding when size is specified; if not given, padded rows are filled with zeros. See jax.numpy.nonzero() for more information.
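The interaction of size and fill_value can be seen on the two-dimensional array from the examples, which has three nonzero elements. This is a sketch; the particular values size=5, size=2, size=4 and fill_value=-1 are arbitrary illustrations.

```python
import jax.numpy as jnp

x = jnp.array([[1, 0, 2],
               [0, 3, 0]])  # three nonzero elements

# size larger than the nonzero count: extra rows are filled with fill_value.
padded = jnp.argwhere(x, size=5, fill_value=-1)
# size smaller than the nonzero count: only the first `size` rows are kept.
truncated = jnp.argwhere(x, size=2)
# fill_value left as None: padded rows are filled with zeros.
zero_padded = jnp.argwhere(x, size=4)

print(padded.tolist())       # [[0, 0], [0, 2], [1, 1], [-1, -1], [-1, -1]]
print(truncated.tolist())    # [[0, 0], [0, 2]]
print(zero_padded.tolist())  # [[0, 0], [0, 2], [1, 1], [0, 0]]
```

A fill value such as -1 that can never be a valid index makes padded rows easy to distinguish from real coordinates, whereas the default zero padding is indistinguishable from a genuine hit at index 0.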

Returns:

a two-dimensional array of shape [size, a.ndim]. If size is not specified as an argument, the leading dimension equals the number of nonzero elements in a.

Return type:

Array
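Because each row of the result holds one coordinate per axis, the transposed result can be used as an advanced index to gather the nonzero values. A minimal sketch, reusing the array from the examples:

```python
import jax.numpy as jnp

x = jnp.array([[1, 0, 2],
               [0, 3, 0]])
idx = jnp.argwhere(x)  # shape (number of nonzeros, x.ndim) == (3, 2)

# idx.T has one row per axis; unpacking it into a tuple yields one index
# array per axis, which is the form advanced indexing expects.
values = x[tuple(idx.T)]

print(idx.shape)        # (3, 2)
print(values.tolist())  # [1, 2, 3]
```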

Examples

Two-dimensional array:

>>> x = jnp.array([[1, 0, 2],
...                [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)

Equivalent computation using jax.numpy.column_stack() and jax.numpy.nonzero():

>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)

Special case for zero-dimensional (i.e. scalar) inputs:

>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)