
jax.numpy.lexsort


jax.numpy.lexsort(keys, axis=-1)

Sort a sequence of keys in lexicographic order.

JAX implementation of numpy.lexsort().

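Parameters:

keys – a sequence of arrays to sort; all arrays must have the same shape. The last key in the sequence is used as the primary key.
axis – the axis along which to sort (default: -1).

Returns:

An array of integers of shape keys[0].shape giving the indices of the entries in lexicographically-sorted order.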

Return type:

Array

See also

jax.numpy.argsort(): sort a single array by index.
jax.lax.sort(): direct XLA sorting API.

Examples

lexsort() with a single key is equivalent to argsort():

>>> import jax.numpy as jnp
>>> key1 = jnp.array([4, 2, 3, 2, 5])
>>> jnp.lexsort([key1])
Array([1, 3, 2, 0, 4], dtype=int32)
>>> jnp.argsort(key1)
Array([1, 3, 2, 0, 4], dtype=int32)
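Because a single key reduces to a stable argsort, one way to picture the multi-key case is as a chain of stable argsorts, applied from the least significant key (the first entry of keys) up to the primary key (the last entry). The following is a minimal sketch under that reading. The helper lexsort_via_argsort is hypothetical, is not necessarily how JAX implements lexsort(), and relies on jnp.argsort() being stable, which is its default.

```python
import jax.numpy as jnp


def lexsort_via_argsort(keys, axis=-1):
    """Hypothetical helper: lexicographic argsort built from stable argsorts.

    As with jnp.lexsort, the last entry of `keys` is the primary key.
    """
    keys = [jnp.asarray(k) for k in keys]
    # Start with the least significant key, keys[0]. jnp.argsort is stable
    # by default, so equal values keep their relative order.
    order = jnp.argsort(keys[0], axis=axis)
    for k in keys[1:]:
        # View the next, more significant key in the current order...
        k_in_order = jnp.take_along_axis(k, order, axis=axis)
        # ...then stably sort by it, so earlier keys only break its ties.
        step = jnp.argsort(k_in_order, axis=axis)
        order = jnp.take_along_axis(order, step, axis=axis)
    return order


key1 = jnp.array([4, 2, 3, 2, 5])
key2 = jnp.array([2, 1, 1, 2, 2])
print(lexsort_via_argsort([key1, key2]).tolist())  # [1, 2, 3, 0, 4]
print(jnp.lexsort([key1, key2]).tolist())          # [1, 2, 3, 0, 4]
```

The stability of each pass is what makes this work: sorting by a more significant key last only reorders entries whose values differ in that key, leaving ties in the order set by the less significant keys.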

With multiple keys, lexsort() uses the last key as the primary key:

>>> key2 = jnp.array([2, 1, 1, 2, 2])
>>> jnp.lexsort([key1, key2])
Array([1, 2, 3, 0, 4], dtype=int32)

The meaning of the indices becomes clearer when printing the sorted keys:

>>> indices = jnp.lexsort([key1, key2])
>>> print(f"{key1[indices]}\n{key2[indices]}")
[2 3 2 4 5]
[1 1 2 2 2]

Notice that the elements of key2 appear in order, and within each run of duplicated key2 values, the corresponding elements of key1 also appear in order.
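Because the last key is the primary one, keys are listed from least to most significant. As a minimal sketch with made-up data: to order dates by year and then by month, month goes first and year goes last; swapping them makes month the primary key instead.

```python
import jax.numpy as jnp

# Made-up dates stored as parallel arrays.
year = jnp.array([2021, 2020, 2021, 2020])
month = jnp.array([3, 11, 1, 2])

# Least significant key first, primary key (year) last.
order = jnp.lexsort([month, year])
print(order.tolist())         # [3, 1, 2, 0]
print(year[order].tolist())   # [2020, 2020, 2021, 2021]
print(month[order].tolist())  # [2, 11, 1, 3]

# Swapping the keys makes month the primary key instead.
print(jnp.lexsort([year, month]).tolist())  # [2, 3, 0, 1]
```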

For multi-dimensional inputs, lexsort() defaults to sorting along the last axis:

>>> key1 = jnp.array([[2, 4, 2, 3],
...                   [3, 1, 2, 2]])
>>> key2 = jnp.array([[1, 2, 1, 3],
...                   [2, 1, 2, 1]])
>>> jnp.lexsort([key1, key2])
Array([[0, 2, 1, 3],
       [1, 3, 2, 0]], dtype=int32)
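The returned indices have the same shape as the keys, so for 2-D inputs plain indexing such as key1[indices] indexes along the leading axis rather than reordering within each row. One way to gather the sorted keys row by row is jnp.take_along_axis(); a short sketch with the same example arrays:

```python
import jax.numpy as jnp

key1 = jnp.array([[2, 4, 2, 3],
                  [3, 1, 2, 2]])
key2 = jnp.array([[1, 2, 1, 3],
                  [2, 1, 2, 1]])
indices = jnp.lexsort([key1, key2])  # same shape as key1: (2, 4)

# take_along_axis applies row i of `indices` to row i of each key.
sorted_key1 = jnp.take_along_axis(key1, indices, axis=-1)
sorted_key2 = jnp.take_along_axis(key2, indices, axis=-1)
print(sorted_key2.tolist())  # [[1, 1, 2, 3], [1, 1, 2, 2]]
print(sorted_key1.tolist())  # [[2, 2, 4, 3], [1, 2, 2, 3]]
```

Each row of sorted_key2 is non-decreasing, and ties in it are broken by the matching entries of sorted_key1, as in the 1-D case.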

A different sort axis can be chosen using the axis keyword; here we sort along the leading axis:

>>> jnp.lexsort([key1, key2], axis=0)
Array([[0, 1, 0, 1],
       [1, 0, 1, 0]], dtype=int32)
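As a sanity check of the axis semantics, sorting along axis 0 should give the same indices as sorting the transposed keys along the last axis and transposing the result back. The indices can then be applied column by column with jnp.take_along_axis(..., axis=0). A short sketch with the same example arrays:

```python
import jax.numpy as jnp

key1 = jnp.array([[2, 4, 2, 3],
                  [3, 1, 2, 2]])
key2 = jnp.array([[1, 2, 1, 3],
                  [2, 1, 2, 1]])

by_axis0 = jnp.lexsort([key1, key2], axis=0)
# Same result: sort the transposed keys along their last axis, then transpose back.
via_transpose = jnp.lexsort([key1.T, key2.T]).T
print(bool((by_axis0 == via_transpose).all()))  # True

# Apply the indices column by column.
print(jnp.take_along_axis(key2, by_axis0, axis=0).tolist())  # [[1, 1, 1, 1], [2, 2, 2, 3]]
```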