
jax.numpy.partition

jax.numpy.partition(a, kth, axis=-1)

Returns a partially-sorted copy of an array.

JAX implementation of numpy.partition(). The JAX version differs from NumPy in its treatment of NaN entries: NaNs with the sign bit set are sorted to the beginning of the array.

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – array to be partitioned.

  • kth (int) – static integer index about which to partition the array.

  • axis (int) – static integer axis along which to partition the array; default is -1.
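Because kth and axis must be static integers, calling jnp.partition inside jax.jit requires marking them as static arguments. The sketch below wraps it in a hypothetical function named partition_jit (not part of JAX) using static_argnames, and contrasts partitioning along axis=0 with the default last axis.

```python
import functools

import jax
import jax.numpy as jnp


# kth and axis are plain Python ints that fix the shapes of the internal
# computations, so they are marked static; new values trigger a retrace.
@functools.partial(jax.jit, static_argnames=("kth", "axis"))
def partition_jit(a, kth, axis=-1):
    return jnp.partition(a, kth, axis=axis)


m = jnp.array([[5, 1, 4],
               [2, 6, 0],
               [3, 7, 8]])

# Partition each column (axis=0) about index 1. With only three entries per
# column, one value below and one above the pivot means each column is sorted.
cols = partition_jit(m, kth=1, axis=0)

# Partition each row (default axis=-1) about index 0: the row minimum comes
# first, and the order of the remaining entries is unspecified.
rows = partition_jit(m, kth=0)

print(cols)
print(rows[:, 0])  # per-row minima
```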

Returns:

A copy of a partitioned at the kth value along axis. The entry at index kth is the value that a full sort would place there, i.e. take(sort(a, axis), kth, axis); entries before kth are less than or equal to it, and entries after kth are greater than or equal to it.

Return type:

Array
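The contract described under Returns can be checked directly against a full sort. The sketch below uses a hypothetical helper, check_partition (not a JAX function), that tests each property on a 1-D integer array; the random input is an assumption chosen to produce duplicate values, so ties at the pivot are exercised.

```python
import jax
import jax.numpy as jnp


def check_partition(x, kth):
    """Hypothetical helper: True if jnp.partition(x, kth) on a 1-D array x
    satisfies the documented contract."""
    out = jnp.partition(x, kth)
    pivot = out[kth]
    # The kth entry equals the value a full sort places at index kth.
    pivot_ok = pivot == jnp.sort(x)[kth]
    # Entries before kth are no larger than the pivot (empty slices pass).
    before_ok = jnp.all(out[:kth] <= pivot)
    # Entries after kth are no smaller than the pivot.
    after_ok = jnp.all(out[kth + 1:] >= pivot)
    # The output is a permutation of the input: same values, same counts.
    perm_ok = jnp.array_equal(jnp.sort(out), jnp.sort(x))
    return bool(pivot_ok & before_ok & after_ok & perm_ok)


key = jax.random.PRNGKey(0)
# A narrow value range (0..9) forces duplicates, exercising ties at the pivot.
x = jax.random.randint(key, (20,), 0, 10)
print([check_partition(x, k) for k in (0, 3, 10, 18)])
```

Note that the check compares values only; the order within each side of the pivot is deliberately not tested.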

Note

The JAX version requires the kth argument to be a static integer rather than a general array. This is implemented via two calls to jax.lax.top_k(). If you’re only accessing the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.
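Calling jax.lax.top_k() directly looks like the sketch below. The second half is an assumption-labeled illustration of how a partition can be assembled from two top_k calls, written as a hypothetical 1-D function partition_via_top_k; it is one plausible way to realize the note above, not JAX's actual source.

```python
import jax
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])

# Largest 3 values in descending order, plus their indices in x.
top_vals, top_idx = jax.lax.top_k(x, 3)

# Smallest 3 values: negate, take the top 3, negate back (ascending order).
# Caveat: negation overflows for the minimum signed integer and does not
# work for unsigned dtypes; this sketch ignores those cases.
bottom_vals = -jax.lax.top_k(-x, 3)[0]


def partition_via_top_k(x, kth):
    """Hypothetical 1-D sketch of partitioning with two top_k calls."""
    n = x.shape[-1]
    # The kth + 1 smallest values, ascending; the last of them is the pivot.
    bottom = -jax.lax.top_k(-x, kth + 1)[0]
    # The remaining n - kth - 1 largest values, descending.
    top = jax.lax.top_k(x, n - kth - 1)[0]
    return jnp.concatenate([bottom, top])


print(top_vals, top_idx, bottom_vals)
print(partition_via_top_k(x, 4))  # [1 2 3 3 4 9 8 7 6 5]
```

Together the two top_k results cover every element exactly once, which is why the concatenation is a valid partition even when values repeat.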


Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> x_partitioned = jnp.partition(x, kth)
>>> x_partitioned
Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)

The result is a partially-sorted copy of the input. All values before kth are no larger than the pivot value x_partitioned[kth], and all values after kth are no smaller than it:

>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [9 8 7 6 5]

Notice that within smallest_values and largest_values, the returned order is arbitrary and implementation-dependent.
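If downstream code depends on a deterministic order, one option is to sort each side of the pivot explicitly, as in this minimal sketch that reuses the example input above. Sorting the slices is an extra step chosen for illustration, not something jnp.partition does itself.

```python
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
kth = 4
x_partitioned = jnp.partition(x, kth)

# The order within each side of the pivot is unspecified, so sort the
# slices explicitly when a fixed order is needed.
smallest_sorted = jnp.sort(x_partitioned[:kth])
largest_sorted = jnp.sort(x_partitioned[kth + 1:])
print(smallest_sorted, x_partitioned[kth], largest_sorted)
# [1 2 3 3] 4 [5 6 7 8 9]
```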
