jax.numpy.take

jax.numpy.take(a, indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)

Take elements from an array.

JAX implementation of numpy.take(), implemented in terms of jax.lax.gather(). JAX’s behavior differs from NumPy in the case of out-of-bounds indices; see the mode parameter below.
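A minimal sketch of that difference, assuming NumPy and JAX are both installed: NumPy's default mode='raise' rejects an out-of-bounds index with an IndexError, while jnp.take with its default mode returns NaN for a floating-point input. The names values and bad_indices are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

values = np.array([10.0, 20.0, 30.0], dtype=np.float32)
bad_indices = np.array([0, 5])  # index 5 is out of bounds for length 3

# NumPy: the default mode='raise' rejects out-of-bounds indices.
try:
    np.take(values, bad_indices)
    numpy_raised = False
except IndexError:
    numpy_raised = True

# JAX: the default mode fills out-of-bounds positions (NaN for floats).
jax_result = jnp.take(jnp.asarray(values), jnp.asarray(bad_indices))

print(numpy_raised)  # True
print(jax_result)    # [10. nan]
```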

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – array from which to take values.

  • indices (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – N-dimensional array of integer indices of values to take from the array.

  • axis (int | None) – the axis along which to take values. If not specified, the array will be flattened before indexing is applied.

  • mode (str | None) – Out-of-bounds indexing mode, either "fill" or "clip". The default mode="fill" returns invalid values (e.g. NaN) for out-of-bounds indices; the fill_value argument gives control over this value. For more discussion of mode options, see jax.numpy.ndarray.at.

  • fill_value (bool | number | bool | int | float | complex | None) – The fill value to return for out-of-bounds indices when mode is ‘fill’. Ignored otherwise. Defaults to NaN for inexact types, the most negative representable value for signed integer types, the largest representable value for unsigned integer types, and True for booleans.

  • unique_indices (bool) – If True, the implementation will assume that the indices are unique, which can result in more efficient execution on some backends. If set to True and indices are not unique, the output is undefined.

  • indices_are_sorted (bool) – If True, the implementation will assume that the indices are sorted in ascending order, which can lead to more efficient execution on some backends. If set to True and indices are not sorted, the output is undefined.

  • out (None) – Not supported; must be None. Present for compatibility with the NumPy signature.

Returns:

Array of values extracted from a.

Return type:

Array
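The fill_value defaults listed above can be checked directly. A small sketch, with dtypes given explicitly so the result does not depend on JAX's 64-bit setting; idx and the other variable names are illustrative:

```python
import jax.numpy as jnp

idx = jnp.array([0, 10])  # index 10 is out of bounds for every array below

# Default fill values under mode="fill" (fill_value=None):
f = jnp.take(jnp.array([1.5, 2.5], dtype=jnp.float32), idx)  # NaN for inexact types
i = jnp.take(jnp.array([1, 2], dtype=jnp.int32), idx)        # most negative int32
u = jnp.take(jnp.array([1, 2], dtype=jnp.uint8), idx)        # largest uint8, i.e. 255
b = jnp.take(jnp.array([False, False]), idx)                 # True for booleans

# An explicit fill_value overrides the dtype-based default.
i0 = jnp.take(jnp.array([1, 2], dtype=jnp.int32), idx, fill_value=0)
```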


Examples

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 6.]])
>>> indices = jnp.array([2, 0])

Passing no axis results in indexing into the flattened array:

>>> jnp.take(x, indices)
Array([3., 1.], dtype=float32)
>>> x.ravel()[indices]  # equivalent indexing syntax
Array([3., 1.], dtype=float32)
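With axis=None, the result takes the shape of indices itself, whatever the shape of a. A short sketch with an illustrative 3-D input and 2-D indices:

```python
import jax.numpy as jnp

x = jnp.arange(24.0).reshape(2, 3, 4)   # 24 elements in total
indices = jnp.array([[0, 5], [23, 7]])  # 2-D array of flat positions

out = jnp.take(x, indices)              # axis=None: index into x.ravel()

# With axis=None the result has the same shape as `indices`.
print(out.shape)  # (2, 2)
```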

Passing an axis results in applying the index to every subarray along that axis:

>>> jnp.take(x, indices, axis=1)
Array([[3., 1.],
       [6., 4.]], dtype=float32)
>>> x[:, indices]  # equivalent indexing syntax
Array([[3., 1.],
       [6., 4.]], dtype=float32)
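More generally, jnp.take follows NumPy's shape rule for numpy.take: the indexed axis of a is replaced by the full shape of indices, giving a.shape[:axis] + indices.shape + a.shape[axis+1:]. The sketch below checks that rule for 1-D and 2-D indices; expected_shape is a hypothetical helper written for illustration, not part of JAX.

```python
import jax.numpy as jnp

a = jnp.arange(24.0).reshape(2, 3, 4)
idx_1d = jnp.array([0, 2, 1, 1, 0])   # shape (5,)
idx_2d = jnp.array([[0, 1], [2, 0]])  # shape (2, 2)

def expected_shape(a, indices, axis):
    """Hypothetical helper: replace axis `axis` of a.shape with indices.shape."""
    return a.shape[:axis] + indices.shape + a.shape[axis + 1:]

r1 = jnp.take(a, idx_1d, axis=1)
r2 = jnp.take(a, idx_2d, axis=1)
print(r1.shape, expected_shape(a, idx_1d, 1))  # (2, 5, 4) (2, 5, 4)
print(r2.shape, expected_shape(a, idx_2d, 1))  # (2, 2, 2, 4) (2, 2, 2, 4)
```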

Out-of-bounds indices are filled with invalid values. For float inputs, this is NaN:

>>> jnp.take(x, indices, axis=0)
Array([[nan, nan, nan],
       [ 1.,  2.,  3.]], dtype=float32)
>>> x.at[indices].get(mode='fill', fill_value=jnp.nan)  # equivalent indexing syntax
Array([[nan, nan, nan],
       [ 1.,  2.,  3.]], dtype=float32)
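Because the fill happens inside the gather instead of raising an error, the same call also works under jax.jit, where the index values are not known when the function is traced. A sketch, assuming x contains no genuine NaN values so that an all-NaN row can only come from an out-of-bounds index; take_rows is a hypothetical name:

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])

@jax.jit
def take_rows(x, indices):
    # `indices` is traced here, so its bounds cannot be checked in Python;
    # rows for out-of-bounds indices are filled with NaN instead.
    return jnp.take(x, indices, axis=0)

rows = take_rows(x, jnp.array([2, 0]))

# Assuming x has no NaNs of its own, an all-NaN row marks an out-of-bounds index.
is_filled_row = jnp.all(jnp.isnan(rows), axis=1)
print(is_filled_row)  # [ True False]
```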

This default out-of-bounds behavior can be adjusted using the mode parameter. For example, we can instead clip to the last valid index:

>>> jnp.take(x, indices, axis=0, mode='clip')
Array([[4., 5., 6.],
       [1., 2., 3.]], dtype=float32)
>>> x.at[indices].get(mode='clip')  # equivalent indexing syntax
Array([[4., 5., 6.],
       [1., 2., 3.]], dtype=float32)
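For non-negative indices, mode='clip' amounts to clamping each index into [0, n - 1], where n is the length of the indexed axis, before taking. A sketch checking that equivalence on the same x (negative indices are left out of this sketch):

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])
indices = jnp.array([2, 0, 7])  # 2 and 7 are out of bounds along axis 0 (size 2)

clipped = jnp.take(x, indices, axis=0, mode='clip')

# Clamp the (non-negative) indices by hand into [0, x.shape[0] - 1], then take.
manual = jnp.take(x, jnp.clip(indices, 0, x.shape[0] - 1), axis=0)
print(jnp.array_equal(clipped, manual))  # True
```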