jax.numpy.sort


jax.numpy.sort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)

Return a sorted copy of an array.

JAX implementation of numpy.sort().

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – array to sort

  • axis (int | None) – integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, then a is flattened before being sorted.

  • stable (bool) – boolean specifying whether a stable sort should be used. Default=True.

  • descending (bool) – boolean specifying whether to sort in descending order. Default=False.

  • kind (None) – deprecated; instead specify sort algorithm using stable=True or stable=False.

  • order (None) – not supported by JAX
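
As a minimal sketch of how the axis and descending parameters above behave, the following sorts a small 2-D array three ways. The array values and variable names are illustrative assumptions, not part of the original documentation.

```python
import jax.numpy as jnp

# Illustrative input (hypothetical values chosen for this sketch).
x = jnp.array([[3, 1, 2],
               [6, 5, 4]])

# axis=None flattens the input before sorting.
flat_sorted = jnp.sort(x, axis=None)

# descending=True sorts from largest to smallest along the chosen axis.
desc_sorted = jnp.sort(x, axis=-1, descending=True)

# axis=0 sorts each column independently.
col_sorted = jnp.sort(x, axis=0)

print(flat_sorted)
print(desc_sorted)
print(col_sorted)
```

The kind and order arguments are deliberately left out here, since the parameter list describes them as deprecated and unsupported respectively.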

Returns:

Sorted array of shape a.shape (if axis is an integer) or of shape (a.size,) (if axis is None).

Return type:

Array
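
The return shapes described above can be checked with a short sketch. The input array and variable names are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp

# Illustrative 3x2 input (hypothetical values).
a = jnp.array([[4, 2],
               [3, 1],
               [6, 5]])

# With an integer axis, the result keeps the input shape.
sorted_axis0 = jnp.sort(a, axis=0)

# With axis=None, the result is one-dimensional with a.size elements.
sorted_flat = jnp.sort(a, axis=None)

print(sorted_axis0.shape, sorted_flat.shape)
print(isinstance(sorted_flat, jax.Array))
```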

Examples

Simple 1-dimensional sort

>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> jnp.sort(x)
Array([1, 1, 2, 3, 4, 5], dtype=int32)

Sort along the last axis of an array:

>>> x = jnp.array([[2, 1, 3],
...                [4, 3, 6]])
>>> jnp.sort(x, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)

