jax.numpy.partition
jax.numpy.partition(a, kth, axis=-1)
Returns a partially-sorted copy of an array.
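JAX implementation of numpy.partition(). The JAX version differs from NumPy in the treatment of NaN entries: NaNs which have the negative bit set are sorted to the beginning of the array.
- Parameters:
  - a: the array to be partitioned.
  - kth: static integer index about which to partition the array.
  - axis: integer axis along which to partition; defaults to -1 (the last axis).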
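- Returns:
  A copy of a partitioned at the kth position along axis. The entry at index kth is the value that would occupy that position if a were fully sorted along axis. Entries before kth are less than or equal to this pivot value, and entries after kth are greater than or equal to it.
- Return type:
  Array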
Note
The JAX version requires the kth argument to be a static integer rather than a general array. This is implemented via two calls to jax.lax.top_k(). If you only need the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.
See also
- jax.numpy.sort(): full sort
- jax.numpy.argpartition(): indirect partial sort
- jax.lax.top_k(): directly find the top k entries
- jax.lax.approx_max_k(): compute the approximate top k entries
- jax.lax.approx_min_k(): compute the approximate bottom k entries
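The Note states that partition is built from two calls to jax.lax.top_k(). The sketch below shows one way that can work along the last axis. It is an illustration under stated assumptions, not the library's source. The kth + 1 smallest values come from negating the input and taking its top entries, and the remaining n - kth - 1 values are simply the largest ones. The helper name partition_via_top_k is hypothetical. The sketch assumes a signed-integer or floating-point dtype (it relies on negation), a static Python int with 0 <= kth < n - 1, and it ignores NaN handling. The last lines show the single top_k call that suffices when only the largest k values are needed.

```python
import jax.numpy as jnp
from jax import lax


def partition_via_top_k(a, kth):
    """Hypothetical helper: partition the last axis of `a` about index `kth`.

    Sketch only: assumes a signed-integer or floating-point dtype and a
    static Python int `kth` with 0 <= kth < a.shape[-1] - 1; NaNs are not
    handled specially.
    """
    n = a.shape[-1]
    # lax.top_k works on the last axis and returns (values, indices),
    # with values in descending order.
    # Smallest kth + 1 values: top entries of -a, negated back, so they
    # come out ascending and end with the pivot value.
    bottom = -lax.top_k(-a, kth + 1)[0]
    # Remaining n - kth - 1 values: the largest ones, in descending order.
    top = lax.top_k(a, n - kth - 1)[0]
    return jnp.concatenate([bottom, top], axis=-1)


x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
print(partition_via_top_k(x, 4))  # [1 2 3 3 4 9 8 7 6 5]

# If only the k largest values are needed, one top_k call is enough.
values, indices = lax.top_k(x, 3)
print(values)  # [9 8 7]
```

Because lax.top_k returns values in descending order, the lower part comes out ascending (after negating back) and the upper part descending. For this input the sketch happens to reproduce the output shown in the Examples, but the order within each side should not be relied upon.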
Examples
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> x_partitioned = jnp.partition(x, kth)
>>> x_partitioned
Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)
The result is a partially-sorted copy of the input. All values before kth are less than or equal to the pivot value, and all values after kth are greater than or equal to it:
>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [9 8 7 6 5]
Notice that within smallest_values and within largest_values, the order of the returned values is arbitrary and implementation-dependent.
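Because the order on each side of the pivot is unspecified, a test of jnp.partition is more robust if it checks the partition property rather than an exact output. The sketch below uses a hypothetical helper, is_valid_partition, that compares against jnp.sort(). It checks four conditions: the result holds the same values as the input, the pivot equals the kth entry of the fully sorted array, entries before the pivot are no larger than it, and entries after are no smaller. The second call also exercises the axis argument.

```python
import jax.numpy as jnp


def is_valid_partition(original, partitioned, kth, axis=-1):
    """Hypothetical helper: check the partition invariant along `axis`."""
    # Move the partitioned axis to the end so slicing is uniform.
    orig = jnp.moveaxis(original, axis, -1)
    part = jnp.moveaxis(partitioned, axis, -1)
    sorted_vals = jnp.sort(orig, axis=-1)
    # Slice with kth:kth + 1 to keep a length-1 axis for broadcasting.
    pivot = part[..., kth:kth + 1]
    same_values = jnp.array_equal(jnp.sort(part, axis=-1), sorted_vals)
    pivot_ok = jnp.array_equal(pivot, sorted_vals[..., kth:kth + 1])
    left_ok = jnp.all(part[..., :kth] <= pivot)
    right_ok = jnp.all(part[..., kth + 1:] >= pivot)
    return all(bool(c) for c in (same_values, pivot_ok, left_ok, right_ok))


x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
print(is_valid_partition(x, jnp.partition(x, 4), 4))  # True

# Partition each column of a 2-D array about row index 1.
m = jnp.array([[5, 1, 4], [2, 9, 0], [7, 3, 8]])
print(is_valid_partition(m, jnp.partition(m, 1, axis=0), 1, axis=0))  # True
```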
