jax.numpy.sort
jax.numpy.sort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)
Return a sorted copy of an array.
JAX implementation of numpy.sort().
- Parameters:
a (Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | TypedNdArray) – array to sort
axis (int | None) – integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, then a is flattened before being sorted.
stable (bool) – boolean specifying whether a stable sort should be used. Default=True.
descending (bool) – boolean specifying whether to sort in descending order. Default=False.
kind (None) – deprecated; instead, specify the sort algorithm using stable=True or stable=False.
order (None) – not supported by JAX
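A minimal sketch of the descending and stable parameters, assuming a JAX version whose jnp.sort accepts descending (as the signature above indicates). The variable names are illustrative only. It checks that a descending value sort matches the ascending result read back to front, and that stable=False does not change the sorted values, since equal values cannot be told apart in the output of a plain value sort.

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 2, 1, 5])

# Ascending (the default) and descending sorts of the same values.
asc = jnp.sort(x)
desc = jnp.sort(x, descending=True)

# For a value sort, the descending result holds the same values as the
# ascending result reversed.
assert jnp.array_equal(desc, jnp.flip(asc))

# stable=False allows a possibly unstable algorithm. Equal values are
# indistinguishable in the output, so the sorted values are unchanged.
assert jnp.array_equal(jnp.sort(x, stable=False), asc)

print(asc.tolist())   # [1, 1, 2, 3, 5]
print(desc.tolist())  # [5, 3, 2, 1, 1]
```

Stability only becomes observable when the positions of equal elements matter, for example with jax.numpy.argsort.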
- Returns:
Sorted array of shape a.shape (if axis is an integer) or of shape (a.size,) (if axis is None).
- Return type:
Array
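The two return shapes described above can be checked directly. This is a small illustrative sketch; the array a and the names flat and rows are chosen here for the example.

```python
import jax.numpy as jnp

a = jnp.array([[3, 1],
               [0, 2]])

# axis=None flattens `a` before sorting, so the result has shape (a.size,).
flat = jnp.sort(a, axis=None)
assert flat.shape == (a.size,)

# With an integer axis, the result keeps the input shape a.shape.
rows = jnp.sort(a, axis=-1)
assert rows.shape == a.shape

print(flat.tolist())  # [0, 1, 2, 3]
print(rows.tolist())  # [[1, 3], [0, 2]]
```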
Examples
Simple 1-dimensional sort:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> jnp.sort(x)
Array([1, 1, 2, 3, 4, 5], dtype=int32)
Sort along the last axis of an array:
>>> x = jnp.array([[2, 1, 3],
...                [4, 3, 6]])
>>> jnp.sort(x, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)
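To contrast the axis choices on a 2-D array, here is a short sketch using a different matrix (chosen so that sorting each column visibly changes it). It shows that axis=0 sorts down each column, while axis=1 sorts across each row and, for 2-D input, matches the default axis=-1.

```python
import jax.numpy as jnp

x = jnp.array([[4, 1, 6],
               [2, 3, 5]])

# axis=0 sorts each column independently.
by_column = jnp.sort(x, axis=0)

# axis=1 sorts each row; for a 2-D array this equals the default axis=-1.
by_row = jnp.sort(x, axis=1)
assert jnp.array_equal(by_row, jnp.sort(x))

print(by_column.tolist())  # [[2, 1, 5], [4, 3, 6]]
print(by_row.tolist())     # [[1, 4, 6], [2, 3, 5]]
```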
See also
jax.numpy.argsort(): return indices of sorted values.
jax.numpy.lexsort(): lexicographical sort of multiple arrays.
jax.lax.sort(): lower-level function wrapping XLA’s Sort operator.
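As a sketch of how these related functions line up with jnp.sort for a single array, the example below gathers values with the indices from jnp.argsort via jnp.take_along_axis, and calls jax.lax.sort with its dimension argument set to the last axis. The input values are arbitrary illustrative data.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[2.0, -1.0, 0.5],
               [7.0, 3.0, -4.0]])

# Gathering x with the indices from argsort reproduces the sorted values.
idx = jnp.argsort(x, axis=-1)
via_argsort = jnp.take_along_axis(x, idx, axis=-1)
assert jnp.array_equal(via_argsort, jnp.sort(x, axis=-1))

# jax.lax.sort sorts a single operand along `dimension` (here the last axis).
via_lax = jax.lax.sort(x, dimension=-1)
assert jnp.array_equal(via_lax, jnp.sort(x))

print(via_argsort.tolist())  # [[-1.0, 0.5, 2.0], [-4.0, 3.0, 7.0]]
```

Use argsort when the permutation itself is needed, for example to reorder a second array by the keys in x.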
