
jax.numpy.argpartition

jax.numpy.argpartition(a, kth, axis=-1)

Returns indices that partially sort an array.

JAX implementation of numpy.argpartition(). The JAX version differs from NumPy in its treatment of NaN entries: NaNs with the sign bit set (negative NaNs) are sorted to the beginning of the array.

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – array to be partitioned.

  • kth (int) – static integer index about which to partition the array.

  • axis (int) – static integer axis along which to partition the array; default is -1.
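Because kth and axis must be static, calling jnp.argpartition from inside your own jitted function requires marking them as static arguments. The sketch below uses a hypothetical wrapper named argpartition_jit (not part of JAX) and partitions a small 2-D array along axis=0, then gathers the values with jnp.take_along_axis so you can see that each column is split about its element of rank 1.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical wrapper: kth and axis are Python ints, so they are marked static.
# Passing a traced (non-static) kth would fail, because it must be known at trace time.
@partial(jax.jit, static_argnames=("kth", "axis"))
def argpartition_jit(a, kth, axis=-1):
    return jnp.argpartition(a, kth, axis=axis)


a = jnp.array([[5, 1, 7],
               [2, 9, 3],
               [8, 4, 6],
               [0, 3, 2]])

# Partition each column (axis=0) about its element of rank 1 (the 2nd-smallest).
idx = argpartition_jit(a, kth=1, axis=0)

# Gather the values in partitioned order; row 1 now holds each column's pivot.
parts = jnp.take_along_axis(a, idx, axis=0)
print(parts[1])
```

Row 1 of parts holds the second-smallest value of each column; rows above it hold values no larger than the pivot and rows below hold values no smaller.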

Returns:

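Indices which partition a at the kth value along axis. The entries before kth are indices of values less than or equal to the pivot, and the entries after kth are indices of values greater than or equal to the pivot. The pivot is the kth-smallest value along axis, i.e. take(sort(a, axis=axis), kth, axis), which in general differs from take(a, kth, axis).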

Return type:

Array
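The pivot is the value that would occupy position kth if a were fully sorted, not a[kth]. The helper is_argpartition below is a hypothetical test utility (not part of the JAX API), written as a minimal sketch for 1-D inputs only. It checks that property, the ordering on either side of the pivot, and that the returned indices form a permutation.

```python
import jax.numpy as jnp


def is_argpartition(a, idx, kth):
    """Return True if idx partitions the 1-D array a about its kth-smallest value."""
    vals = a[idx]
    pivot = vals[kth]
    # The pivot must equal the kth element of the fully sorted array.
    same_pivot = pivot == jnp.sort(a)[kth]
    # Everything before kth is <= pivot; everything after kth is >= pivot.
    left_ok = jnp.all(vals[:kth] <= pivot)
    right_ok = jnp.all(vals[kth + 1:] >= pivot)
    # idx must contain each position 0..n-1 exactly once.
    is_perm = jnp.array_equal(jnp.sort(idx), jnp.arange(a.shape[0]))
    return bool(same_pivot & left_ok & right_ok & is_perm)


x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
for kth in (0, 4, 8):
    assert is_argpartition(x, jnp.argpartition(x, kth), kth)

# Here x[4] is 1, but the pivot for kth=4 is the 5th-smallest value, 4.
print(x[4], jnp.sort(x)[4])
```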

Note

The JAX version requires the kth argument to be a static integer rather than a general array. This is implemented via two calls to jax.lax.top_k(). If you're only accessing the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.
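When only the k largest or smallest entries are needed, jax.lax.top_k() can be called directly. It operates on the last axis and returns a (values, indices) pair with values in descending order. In the minimal sketch below the smallest entries are obtained by negating the input, which assumes a signed integer or floating-point dtype (negating unsigned integers would wrap around).

```python
import jax
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
k = 4

# k largest values (descending) and their positions.
top_vals, top_idx = jax.lax.top_k(x, k)

# k smallest values: take the top k of -x, then undo the negation.
# This yields the values in ascending order.
neg_vals, bottom_idx = jax.lax.top_k(-x, k)
bottom_vals = -neg_vals

print(top_vals, top_idx)
print(bottom_vals, bottom_idx)
```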

See also

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> idx = jnp.argpartition(x, kth)
>>> idx
Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)

The result is a sequence of indices that partially sort the input. All indices before kth refer to values no larger than the pivot value, and all indices after kth refer to values no smaller than the pivot value:

>>> x_partitioned = x[idx]
>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [6 8 9 7 5]

Notice that within smallest_values and largest_values, the returned order is arbitrary and implementation-dependent.
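Because the order within each side of the partition is arbitrary, a common pattern when the k smallest values are needed in sorted order is to partition first and then sort only those k entries. A minimal sketch, with k chosen for illustration:

```python
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
k = 4

# Partition about rank k-1 so the k smallest values occupy positions 0..k-1
# (in no guaranteed order), then keep just those k indices.
idx = jnp.argpartition(x, k - 1)[:k]

# Sort only these k entries by value (k elements rather than all n).
order = jnp.argsort(x[idx])
smallest_k_idx = idx[order]
smallest_k = x[smallest_k_idx]
print(smallest_k)
```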

