jax.numpy.argpartition
jax.numpy.argpartition(a, kth, axis=-1)
Returns indices that partially sort an array.
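JAX implementation of numpy.argpartition(). The JAX version differs from NumPy in the treatment of NaN entries: NaNs which have the negative bit set are sorted to the beginning of the array.
- Parameters:
  - a: array to be partitioned.
  - kth: static integer index about which to partition the array.
  - axis: integer axis along which to partition the array. Defaults to -1 (the last axis).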
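- Returns:
  Indices which partition a along axis about the pivot, the value that a full sort would place at position kth, i.e. jnp.take(jnp.sort(a, axis=axis), kth, axis=axis). The entries before kth are indices of values less than or equal to the pivot, and the entries after kth are indices of values greater than or equal to the pivot.
- Return type:
  Array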
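The following is a small way to check this description on a concrete array. check_partition is a hypothetical helper that assumes a 1-D input; it compares the partitioned array against a full jnp.sort(), which is how the pivot is defined. Note that the pivot is generally not a[kth]: for the array below, x[4] is 1, while the pivot for kth=4 is 4.

```python
import jax.numpy as jnp


def check_partition(a, kth):
    """Hypothetical helper: does argpartition(a, kth) split 1-D `a` correctly?"""
    part = a[jnp.argpartition(a, kth)]
    # The pivot is the value a full sort would place at position kth.
    pivot = jnp.sort(a)[kth]
    return bool(
        (part[kth] == pivot)
        & jnp.all(part[:kth] <= pivot)       # everything before kth is <= pivot
        & jnp.all(part[kth + 1:] >= pivot)   # everything after kth is >= pivot
    )


x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
print(check_partition(x, 4))  # True
print(x[4], jnp.sort(x)[4])   # 1 4  (a[kth] differs from the pivot)
```

The comparisons use <= and >= because values equal to the pivot (duplicates) may land on either side of position kth.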
Note
The JAX version requires the kth argument to be a static integer rather than a general array. Internally, the function is implemented via two calls to jax.lax.top_k(). If you’re only accessing the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.

See also
- jax.numpy.partition(): direct partial sort
- jax.numpy.argsort(): full indirect sort
- jax.lax.top_k(): directly find the top k entries
- jax.lax.approx_max_k(): compute the approximate top k entries
- jax.lax.approx_min_k(): compute the approximate bottom k entries
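The Note says argpartition is built from two calls to jax.lax.top_k(). Below is a minimal 1-D sketch of how two such calls can produce a partition. argpartition_sketch is a hypothetical name, and this is an illustration under stated assumptions (1-D signed-integer or floating input, 0 <= kth < len(a) - 1), not JAX's source. The last lines show the direct jax.lax.top_k() call the Note recommends when only the top or bottom k values are needed.

```python
import jax
import jax.numpy as jnp


def argpartition_sketch(a, kth):
    """Hypothetical 1-D argpartition built from two jax.lax.top_k calls.

    Assumes a 1-D signed-integer or floating array and 0 <= kth < a.shape[0] - 1.
    """
    n = a.shape[0]
    # 1st call: the kth + 1 largest entries of -a are the kth + 1 smallest
    # entries of a, in ascending order of a, so bottom[kth] indexes the pivot.
    _, bottom = jax.lax.top_k(-a, kth + 1)
    # 2nd call: flag already-chosen positions with 0 and the rest with 1.
    # Selecting the n - kth - 1 largest flags (rather than largest values)
    # keeps the two index sets disjoint even when a contains duplicates.
    flags = jnp.ones(n).at[bottom].set(0.0)
    _, top = jax.lax.top_k(flags, n - kth - 1)
    return jnp.concatenate([bottom, top])


x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
part = x[argpartition_sketch(x, 4)]
print(part[4])  # 4

# When only the k largest or smallest values are needed, one top_k call suffices.
top_vals, top_idx = jax.lax.top_k(x, 3)  # descending: values [9 8 7], indices [5 1 6]
bottom_vals = -jax.lax.top_k(-x, 3)[0]   # ascending: [1 2 3]
```

Negating the input to get the smallest entries assumes a signed or floating dtype; for unsigned integers the negation would wrap around.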
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> idx = jnp.argpartition(x, kth)
>>> idx
Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)
The result is a sequence of indices that partially sort the input. Here, the indices before position kth point to values smaller than the pivot value, and the indices after kth point to values larger than it:

>>> x_partitioned = x[idx]
>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [6 8 9 7 5]
Notice that among smallest_values and largest_values, the returned order is arbitrary and implementation-dependent.
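The examples above use a 1-D array. A minimal sketch for the axis parameter follows. partition_along is a hypothetical helper that partitions a 2-D array along a chosen axis and gathers the partitioned values with jnp.take_along_axis(). Because kth must be a static integer, the helper is jitted with kth (and axis, also a plain integer) listed in static_argnames.

```python
from functools import partial

import jax
import jax.numpy as jnp


# kth must be static for argpartition; axis is a Python int too, so mark both static.
@partial(jax.jit, static_argnames=("kth", "axis"))
def partition_along(a, kth, axis):
    """Hypothetical helper: partition `a` along `axis` about sorted position kth."""
    idx = jnp.argpartition(a, kth, axis=axis)
    # Gather the values in partitioned order along the same axis.
    return jnp.take_along_axis(a, idx, axis=axis)


m = jnp.array([[6, 8, 4, 3, 1],
               [9, 7, 5, 2, 3]])

rows = partition_along(m, 2, axis=1)  # each row about its sorted position 2
cols = partition_along(m, 0, axis=0)  # each column's minimum moves to row 0
print(rows[:, 2])  # [4 5]
print(cols[0])     # [6 7 4 2 1]
```

Calling the helper with a traced (non-static) kth would fail, which is the restriction described in the Note.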
