
jax.numpy.rollaxis


jax.numpy.rollaxis(a, axis, start=0)

Roll the specified axis to a given position.

JAX implementation of numpy.rollaxis().

This function exists for compatibility with NumPy, but in most cases the newer jax.numpy.moveaxis() should be used instead, because the meaning of its arguments is more intuitive.
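To see how an existing rollaxis call maps onto moveaxis(), the sketch below translates one into the other. It assumes the normalization rule documented for the start parameter (a negative start is shifted by ndim, and a start past axis lands one position earlier) and assumes `a` is already an array. `rollaxis_as_moveaxis` is a hypothetical helper written for illustration; it is not part of JAX.

```python
import jax.numpy as jnp


def rollaxis_as_moveaxis(a, axis, start=0):
    """Hypothetical helper (not a JAX API): rollaxis expressed via jnp.moveaxis."""
    ndim = a.ndim
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis={axis} is out of range for a {ndim}-d array")
    if not -ndim <= start <= ndim:
        raise ValueError(f"start={start} must satisfy {-ndim} <= start <= {ndim}")
    # Normalize negative indices by adding ndim.
    if axis < 0:
        axis += ndim
    if start < 0:
        start += ndim
    # start <= axis: the axis ends up exactly at `start`.
    # start > axis: it ends up one position before `start`.
    destination = start if start <= axis else start - 1
    return jnp.moveaxis(a, axis, destination)


a = jnp.arange(120).reshape(2, 3, 4, 5)
print(rollaxis_as_moveaxis(a, 2).shape)          # (4, 2, 3, 5)
print(rollaxis_as_moveaxis(a, 1, a.ndim).shape)  # (2, 4, 5, 3)
```

With moveaxis(a, source, destination) the final position of the axis is stated directly, so the start-versus-axis case split disappears from the call site.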

Parameters:
  • a (ArrayLike) – input array.

  • axis (int) – index of the axis to roll forward.

  • start (int) – index toward which the axis will be rolled (default = 0). After normalizing negative axes, if start <= axis, the axis is rolled to the start index; if start > axis, the axis is rolled until the position before start.
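Because the two branches of this rule treat start differently, neighboring start values can produce the same result. One way to observe this, assuming an input whose rolled axis has a unique size so its position can be read off the output shape, is to tabulate where axis 1 lands for every valid start:

```python
import jax.numpy as jnp

# Axis 1 has size 3, the only dimension of that size, so the index of 3
# in the output shape is the position the axis was rolled to.
a = jnp.ones((2, 3, 4, 5))

landing = {start: jnp.rollaxis(a, 1, start).shape.index(3)
           for start in range(-a.ndim, a.ndim + 1)}

print(landing)
# {-4: 0, -3: 1, -2: 1, -1: 2, 0: 0, 1: 1, 2: 1, 3: 2, 4: 3}
```

Here start=1 and start=2 both leave the axis at position 1, and start=-1 normalizes to 3 (landing at position 2) rather than meaning "the end", which is part of why moveaxis() is usually easier to reason about.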

Returns:

Copy of a with rolled axis.

Return type:

Array

Notes

Unlike numpy.rollaxis(), jax.numpy.rollaxis() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this does not affect performance.
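The sketch below contrasts the two behaviors on a small array: writing through NumPy's result changes the original, while the JAX result is a separate immutable array that can only be "updated" functionally through .at[...].set(). The jitted function `roll_then_sum` is a hypothetical example; whether XLA actually avoids materializing the transposed copy is up to the compiler, so the code checks only results, not memory behavior.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: rollaxis returns a view, so writing through it modifies the input.
x_np = np.zeros((2, 3, 4))
view = np.rollaxis(x_np, 2)       # shape (4, 2, 3); view[i, j, k] is x_np[j, k, i]
view[0, 0, 0] = 1.0
print(x_np[0, 0, 0])              # 1.0

# JAX: the result is a new immutable array. .at[...].set() returns yet
# another array and leaves both `rolled` and `x` unchanged.
x = jnp.zeros((2, 3, 4))
rolled = jnp.rollaxis(x, 2)
updated = rolled.at[0, 0, 0].set(1.0)
print(float(x[0, 0, 0]), float(updated[0, 0, 0]))  # 0.0 1.0


# Under jax.jit the rollaxis is traced into a larger computation, where the
# compiler may avoid materializing the transposed copy.
@jax.jit
def roll_then_sum(x):
    return jnp.rollaxis(x, 2).sum(axis=0)


print(roll_then_sum(x).shape)     # (2, 3)
```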

See also

Examples

>>> import jax.numpy as jnp
>>> a = jnp.ones((2, 3, 4, 5))

Roll axis 2 to the start of the array:

>>> jnp.rollaxis(a, 2).shape
(4, 2, 3, 5)

Roll axis 1 to the end of the array:

>>> jnp.rollaxis(a, 1, a.ndim).shape
(2, 4, 5, 3)

Equivalent of these two with moveaxis():

>>> jnp.moveaxis(a, 2, 0).shape
(4, 2, 3, 5)
>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)