jax.numpy.lexsort
jax.numpy.lexsort(keys, axis=-1)
Sort a sequence of keys in lexicographic order.
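JAX implementation of numpy.lexsort().

- Parameters:
  - keys: a sequence of arrays to sort, all with the same shape. The last key in the sequence is used as the primary key.
  - axis (int): the axis along which to sort. Defaults to -1, the last axis.
- Returns:
  An array of integers of shape keys[0].shape giving the indices of the entries in lexicographically-sorted order.
- Return type:
  Array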
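The returned indices have the same shape as each key and can be used to reorder the keys or any other array aligned with them. A minimal sketch with made-up integer-encoded data; `last`, `first`, and `payload` are hypothetical arrays chosen for illustration:

```python
import jax.numpy as jnp

# Hypothetical records: sort by last name (primary key), then first name.
last = jnp.array([2, 0, 2, 1])   # integer-encoded last names
first = jnp.array([1, 3, 0, 2])  # integer-encoded first names

# The last key in the list is the primary key.
idx = jnp.lexsort([first, last])

# idx has the same shape as each key, so it can reorder any aligned array.
payload = jnp.array([10.0, 20.0, 30.0, 40.0])
sorted_payload = payload[idx]
```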
See also
- jax.numpy.argsort(): return the indices that would sort a single array.
- jax.lax.sort(): direct XLA sorting API.
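The connection to jax.lax.sort() can be sketched by sorting the keys together with an array of positions and passing num_keys so that only the keys take part in the comparison. This is an illustrative equivalent for 1-D keys, not a claim about how jnp.lexsort is implemented internally. Because lax.sort treats its first operand as the most significant key, the keys are passed in reverse order:

```python
import jax.numpy as jnp
from jax import lax

key1 = jnp.array([4, 2, 3, 2, 5])
key2 = jnp.array([2, 1, 1, 2, 2])
iota = jnp.arange(key1.shape[0])  # carries original positions through the sort

# lax.sort compares the first num_keys operands lexicographically, first operand
# most significant, so the primary key (key2) goes first. is_stable defaults to
# True, so ties keep their original order; iota is reordered but not compared.
_, _, perm = lax.sort((key2, key1, iota), num_keys=2)
```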
Examples
lexsort() with a single key is equivalent to argsort():

>>> key1 = jnp.array([4, 2, 3, 2, 5])
>>> jnp.lexsort([key1])
Array([1, 3, 2, 0, 4], dtype=int32)
>>> jnp.argsort(key1)
Array([1, 3, 2, 0, 4], dtype=int32)
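The two results agree even when the key has ties, because both break ties by original position (jnp.argsort is stable by default). A quick check on random integer keys with many repeated values; the seed, size, and value range are arbitrary choices:

```python
import jax
import jax.numpy as jnp

# A small value range guarantees many ties, which exercises tie-breaking.
key = jax.random.randint(jax.random.PRNGKey(0), (1000,), 0, 10)

same = bool(jnp.all(jnp.lexsort([key]) == jnp.argsort(key)))
```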
With multiple keys, lexsort() uses the last key as the primary key:

>>> key2 = jnp.array([2, 1, 1, 2, 2])
>>> jnp.lexsort([key1, key2])
Array([1, 2, 3, 0, 4], dtype=int32)
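One way to see which key is primary is a pure-Python reference that sorts positions by tuples built from the keys in reverse order. `lexsort_reference` is a hypothetical helper for illustration only; it relies on Python's stable `sorted` and on lexicographic tuple comparison:

```python
import jax.numpy as jnp

def lexsort_reference(keys):
    """Hypothetical helper: 1-D lexsort via Python's stable sorted()."""
    n = len(keys[0])
    # reversed(keys) puts the last key first in each tuple, so it is compared first.
    return sorted(range(n), key=lambda i: tuple(int(k[i]) for k in reversed(keys)))

key1 = jnp.array([4, 2, 3, 2, 5])
key2 = jnp.array([2, 1, 1, 2, 2])
ref = lexsort_reference([key1, key2])
```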
The meaning of the indices becomes clearer when printing the sorted keys:
>>> indices = jnp.lexsort([key1, key2])
>>> print(f"{key1[indices]}\n{key2[indices]}")
[2 3 2 4 5]
[1 1 2 2 2]
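The ordering visible in this output can also be checked programmatically: each adjacent pair must have either a strictly larger primary key, or an equal primary key and a secondary key that does not decrease. `is_lex_sorted` is a hypothetical helper written for exactly two keys, taking them in the same order as the lexsort() argument list:

```python
import jax.numpy as jnp

def is_lex_sorted(secondary, primary):
    """Hypothetical helper: True if (primary, secondary) pairs are non-decreasing."""
    p0, p1 = primary[:-1], primary[1:]
    s0, s1 = secondary[:-1], secondary[1:]
    return bool(jnp.all((p0 < p1) | ((p0 == p1) & (s0 <= s1))))

key1 = jnp.array([4, 2, 3, 2, 5])
key2 = jnp.array([2, 1, 1, 2, 2])
indices = jnp.lexsort([key1, key2])
ok = is_lex_sorted(key1[indices], key2[indices])
```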
Notice that the elements of key2 appear in order, and within each run of duplicated key2 values the corresponding elements of key1 appear in order.

For multi-dimensional inputs, lexsort() defaults to sorting along the last axis:

>>> key1 = jnp.array([[2, 4, 2, 3],
...                   [3, 1, 2, 2]])
>>> key2 = jnp.array([[1, 2, 1, 3],
...                   [2, 1, 2, 1]])
>>> jnp.lexsort([key1, key2])
Array([[0, 2, 1, 3],
       [1, 3, 2, 0]], dtype=int32)
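Sorting along the last axis treats each row independently. A sketch of that reading, using jax.vmap to apply a 1-D lexsort row by row and jnp.take_along_axis to gather each row's keys in sorted order; this only illustrates the semantics:

```python
import jax
import jax.numpy as jnp

key1 = jnp.array([[2, 4, 2, 3],
                  [3, 1, 2, 2]])
key2 = jnp.array([[1, 2, 1, 3],
                  [2, 1, 2, 1]])

# Map a 1-D lexsort over the leading axis: one independent sort per row.
rowwise = jax.vmap(lambda a, b: jnp.lexsort([a, b]))(key1, key2)
full = jnp.lexsort([key1, key2])  # default axis=-1

# Use the indices to gather each row of key2 in sorted order.
sorted_key2 = jnp.take_along_axis(key2, full, axis=-1)
```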
A different sort axis can be chosen using the axis keyword; here we sort along the leading axis:

>>> jnp.lexsort([key1, key2], axis=0)
Array([[0, 1, 0, 1],
       [1, 0, 1, 0]], dtype=int32)
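Sorting along axis 0 gives the same result as transposing the keys, sorting along the last axis, and transposing the indices back. A short sketch of that equivalence on the same example keys, with jnp.take_along_axis used to read off each column of key2 in sorted order:

```python
import jax.numpy as jnp

key1 = jnp.array([[2, 4, 2, 3],
                  [3, 1, 2, 2]])
key2 = jnp.array([[1, 2, 1, 3],
                  [2, 1, 2, 1]])

by_column = jnp.lexsort([key1, key2], axis=0)
via_transpose = jnp.lexsort([key1.T, key2.T]).T  # each column sorted as a row

# Gather each column of key2 in sorted order along axis 0.
sorted_key2 = jnp.take_along_axis(key2, by_column, axis=0)
```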
