jax.numpy.unravel_index
- jax.numpy.unravel_index(indices, shape)
Convert flat indices into multi-dimensional indices.
JAX implementation of numpy.unravel_index(). The JAX version differs in its treatment of out-of-bound indices: unlike NumPy, negative indices are supported (counting back from the end of the flattened array, as in Python indexing), and out-of-bound indices are clipped to the nearest valid value.
- Parameters:
indices (ArrayLike) – integer array of flat indices
shape (Shape) – shape of multidimensional array to index into
- Returns:
Tuple of unraveled indices, one integer array per dimension of shape
- Return type:
tuple[Array, ...]
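Because the result is a tuple with one index array per dimension, it can be used directly to index the multi-dimensional array. A common pattern (a sketch, not part of the original page) is locating the maximum of a 2D array: jnp.argmax without an axis argument returns a flat index, which unravel_index converts into a row and a column. The array a is an arbitrary illustrative value.

```python
import jax.numpy as jnp

# Illustrative 2D array; flattened in row-major order it is [1, 7, 3, 4, 2, 9].
a = jnp.array([[1, 7, 3],
               [4, 2, 9]])

flat_idx = jnp.argmax(a)  # no axis: index into the flattened array (here 5)
row, col = jnp.unravel_index(flat_idx, a.shape)  # one index per dimension

print(int(row), int(col))  # 1 2
print(int(a[row, col]))    # 9
```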
See also
jax.numpy.ravel_multi_index(): Inverse of this function.
Examples
Start with a 1D array of values and an array of flat indices into it:
>>> import jax.numpy as jnp
>>> x = jnp.array([2., 3., 4., 5., 6., 7.])
>>> indices = jnp.array([1, 3, 5])
>>> print(x[indices])
[3. 5. 7.]
Now if x is reshaped, unravel_index can be used to convert the flat indices into a tuple of indices that access the same entries:
>>> shape = (2, 3)
>>> x_2D = x.reshape(shape)
>>> indices_2D = jnp.unravel_index(indices, shape)
>>> indices_2D
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
>>> print(x_2D[indices_2D])
[3. 5. 7.]
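For in-bound indices, the conversion above is row-major (C-order) arithmetic: with shape (2, 3), flat index i maps to (i // 3, i % 3), so 1, 3, 5 become rows 0, 1, 1 and columns 1, 0, 2. The sketch below generalizes this with repeated jnp.divmod, peeling off the last axis first. unravel_row_major is a hypothetical helper for illustration only; it is not JAX's implementation and does none of the negative-index handling or clipping that jnp.unravel_index applies.

```python
import jax.numpy as jnp

def unravel_row_major(flat, shape):
    """Hypothetical helper: unravel in-bound, non-negative flat indices for a
    row-major layout. No clipping or negative-index handling is done."""
    flat = jnp.asarray(flat)
    out = []
    # Walk the axes from last to first; the remainder is the index along
    # the current axis, the quotient carries over to the next axis.
    for dim in reversed(shape):
        flat, rem = jnp.divmod(flat, dim)
        out.append(rem)
    return tuple(reversed(out))

indices = jnp.array([1, 3, 5])
shape = (2, 3)
manual = unravel_row_major(indices, shape)
builtin = jnp.unravel_index(indices, shape)

# Compare the hand-rolled result with jnp.unravel_index, axis by axis.
for m, b in zip(manual, builtin):
    print(m, b)
```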
The inverse function, ravel_multi_index, can be used to obtain the original indices:
>>> jnp.ravel_multi_index(indices_2D, shape)
Array([1, 3, 5], dtype=int32)
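The out-of-bound handling described at the top of this page can be checked directly. The sketch below uses the same shape (2, 3), which has six valid flat positions (0 through 5); the specific index values are arbitrary choices for illustration. The NumPy call is included only to show the contrast the page describes.

```python
import numpy as np
import jax.numpy as jnp

shape = (2, 3)  # six valid flat positions, 0 through 5

# Negative indices count back from the end, like Python sequence indexing:
# -1 is the last entry (1, 2) and -6 is the first entry (0, 0).
neg = jnp.unravel_index(jnp.array([-1, -6]), shape)
print(neg)

# Indices beyond either end are clipped to the nearest valid position:
# 6 and 100 map to the last entry, -7 maps to the first.
oob = jnp.unravel_index(jnp.array([6, 100, -7]), shape)
print(oob)

# NumPy, by contrast, rejects an out-of-bound flat index.
numpy_raised = False
try:
    np.unravel_index(6, shape)
except ValueError as err:
    numpy_raised = True
    print("NumPy raised ValueError:", err)
```

One consequence of the clipping: passing the clipped result back through jnp.ravel_multi_index yields the flat positions of the clipped entries (5, 5 and 0 here), not the original 6, 100 and -7, so the round trip shown in the examples holds only for in-bound, non-negative indices.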
