
jax.numpy.unravel_index

jax.numpy.unravel_index(indices, shape)

Convert flat indices into multi-dimensional indices.

JAX implementation of numpy.unravel_index(). The JAX version differs in its treatment of out-of-bound indices: unlike NumPy, negative indices are supported, and out-of-bound indices are clipped to the nearest valid value.
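To make these rules concrete, here is a minimal pure-Python sketch of the documented semantics for a single scalar index. It is not JAX's implementation: the helper name unravel_index_clipped is hypothetical, it handles one integer rather than an array, and it assumes every dimension in shape is positive.

```python
def unravel_index_clipped(index: int, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Convert one flat row-major index into a tuple of per-axis indices.

    Hypothetical helper mirroring the documented jnp.unravel_index rules:
    negative indices count back from the end, and out-of-bound indices
    are clipped to the nearest valid position. Assumes every entry of
    `shape` is a positive integer.
    """
    out = [0] * len(shape)
    rest = index
    # Peel off axes from last to first (C / row-major order). Python's
    # divmod floors, so each remainder lands in [0, shape[axis]).
    for axis in reversed(range(len(shape))):
        rest, out[axis] = divmod(rest, shape[axis])
    # The leftover quotient is 0 for 0 <= index < size and -1 for
    # -size <= index < 0; any other value means the index was out of bounds.
    if rest > 0:
        return tuple(s - 1 for s in shape)  # past the end: clip to the last element
    if rest < -1:
        return tuple(0 for _ in shape)  # before the start: clip to the first element
    return tuple(out)


for flat in [1, 3, 5, -1, 6, -7]:
    print(flat, unravel_index_clipped(flat, (2, 3)))
# 1 (0, 1)
# 3 (1, 0)
# 5 (1, 2)
# -1 (1, 2)
# 6 (1, 2)
# -7 (0, 0)
```

For shape (2, 3), the in-range indices 1, 3 and 5 give the same (row, column) pairs as in the Examples section, -1 selects the last element, and 6 and -7 lie outside the six-element range and are clipped to the last and first element respectively.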

Parameters:
  • indices (ArrayLike) – integer array of flat indices

  • shape (Shape) – shape of multidimensional array to index into

Returns:

Tuple of unraveled indices, one array per dimension of shape

Return type:

tuple[Array, …]

See also

jax.numpy.ravel_multi_index(): Inverse of this function.

Examples

Start with a 1D array of values and an array of indices:

>>> import jax.numpy as jnp
>>> x = jnp.array([2., 3., 4., 5., 6., 7.])
>>> indices = jnp.array([1, 3, 5])
>>> print(x[indices])
[3. 5. 7.]

Now if x is reshaped, unravel_index can be used to convert the flat indices into a tuple of indices that access the same entries:

>>> shape = (2, 3)
>>> x_2D = x.reshape(shape)
>>> indices_2D = jnp.unravel_index(indices, shape)
>>> indices_2D
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
>>> print(x_2D[indices_2D])
[3. 5. 7.]

The inverse function, ravel_multi_index, can be used to obtain the original indices:

>>> jnp.ravel_multi_index(indices_2D, shape)
Array([1, 3, 5], dtype=int32)
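The round trip rests on row-major (C-order) arithmetic. Writing the shape as (s_0, ..., s_{n-1}) and the unraveled indices as (i_0, ..., i_{n-1}) (notation introduced here for illustration), an in-bounds flat index f satisfies the relation below; unravel_index recovers each i_k by repeated integer division, starting from the last axis. The second line checks it against the shape (2, 3) example above.

```latex
\begin{aligned}
f &= \sum_{k=0}^{n-1} i_k \prod_{j=k+1}^{n-1} s_j, \qquad 0 \le i_k < s_k, \\
\text{shape } (2, 3):\quad 1 &= 0 \cdot 3 + 1, \quad 3 = 1 \cdot 3 + 0, \quad 5 = 1 \cdot 3 + 2.
\end{aligned}
```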
