
jax.numpy.moveaxis

jax.numpy.moveaxis(a, source, destination)

Move an array axis to a new position.

JAX implementation of numpy.moveaxis(), implemented in terms of jax.lax.transpose().
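To illustrate how a move of axes reduces to a single jax.lax.transpose() call, here is a minimal sketch. The helper names moveaxis_via_transpose and _as_axis_tuple are hypothetical, and the permutation logic follows the approach NumPy's moveaxis uses; it is not claimed to be JAX's exact source. It assumes source axes are distinct (duplicates are left for lax.transpose to reject).

```python
import operator

import jax.numpy as jnp
from jax import lax


def _as_axis_tuple(axes, ndim):
  """Hypothetical helper: normalize an int or a sequence of ints to axes in [0, ndim)."""
  try:
    axes = (operator.index(axes),)                    # a single integer axis
  except TypeError:
    axes = tuple(operator.index(ax) for ax in axes)   # a sequence of axes
  for ax in axes:
    if not -ndim <= ax < ndim:
      raise ValueError(f"axis {ax} is out of bounds for an array of rank {ndim}")
  return tuple(ax % ndim for ax in axes)              # e.g. -1 -> ndim - 1


def moveaxis_via_transpose(a, source, destination):
  """Hypothetical helper: express moveaxis as one lax.transpose call."""
  a = jnp.asarray(a)
  src = _as_axis_tuple(source, a.ndim)
  dst = _as_axis_tuple(destination, a.ndim)
  if len(src) != len(dst):
    raise ValueError("source and destination must have the same number of axes")
  # Keep the axes that are not moved, in their original relative order...
  perm = [ax for ax in range(a.ndim) if ax not in src]
  # ...then insert each moved axis at its destination, smallest destination first,
  # so later insertions (at larger indices) never shift an axis already placed.
  for d, s in sorted(zip(dst, src)):
    perm.insert(d, s)
  # perm[i] names the input axis that becomes output axis i.
  return lax.transpose(a, tuple(perm))
```

For example, with a 4-D input, moving axes (0, 1) to (-1, -2) produces the permutation (2, 3, 1, 0), which matches the transpose() call in the examples.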

Parameters:
  • a (ArrayLike) – input array

  • source (int | Sequence[int]) – index or indices of the axes to move.

  • destination (int | Sequence[int]) – index or indices of the destination positions for the moved axes.
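Both parameters accept either a single int or a sequence of ints. As in NumPy, source[i] is moved to destination[i], and negative indices count from the end. A small illustration (the array contents are arbitrary; only the shapes matter):

```python
import jax.numpy as jnp

a = jnp.zeros((2, 3, 4))

# An int and a one-element tuple describe the same move: axis 0 -> position 2.
print(jnp.moveaxis(a, 0, 2).shape)             # (3, 4, 2)
print(jnp.moveaxis(a, (0,), (2,)).shape)       # (3, 4, 2)

# Negative indices count from the end, so -1 is the last position.
print(jnp.moveaxis(a, 0, -1).shape)            # (3, 4, 2)

# With sequences, source[i] goes to destination[i]; here axes 0 and 2 trade places.
print(jnp.moveaxis(a, (0, 2), (2, 0)).shape)   # (4, 3, 2)
```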

Returns:

Copy of a with axes moved from source to destination.

Return type:

Array

Notes

Unlike numpy.moveaxis(), jax.numpy.moveaxis() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this usually has no performance cost.
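The copy-versus-view difference can be observed directly. This sketch writes through the NumPy result, which changes the original array, and then "updates" the JAX result with the .at[] syntax, since JAX arrays are immutable and cannot be modified in place.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: moveaxis returns a view that shares memory with its input.
x_np = np.zeros((2, 3))
v = np.moveaxis(x_np, 0, -1)
v[0, 0] = 1.0
print(np.shares_memory(v, x_np), x_np[0, 0])   # True 1.0

# JAX: arrays are immutable, so the input can never be mutated through the result.
# .at[...].set(...) returns a new array and leaves both x and y unchanged.
x = jnp.zeros((2, 3))
y = jnp.moveaxis(x, 0, -1)
y2 = y.at[0, 0].set(1.0)
print(float(x[0, 0]), float(y[0, 0]), float(y2[0, 0]))   # 0.0 0.0 1.0
```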

See also

Examples

>>> import jax.numpy as jnp
>>> a = jnp.ones((2, 3, 4, 5))

Move axis 1 to the end of the array:

>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)

Move the last axis to position 1:

>>> jnp.moveaxis(a, -1, 1).shape
(2, 5, 3, 4)

Move multiple axes:

>>> jnp.moveaxis(a, (0, 1), (-1, -2)).shape
(4, 5, 3, 2)

This can also be accomplished via transpose():

>>> a.transpose(2, 3, 1, 0).shape
(4, 5, 3, 2)
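The transpose() equivalence holds for the values as well as the shapes: output axis i of the transposed array is input axis perm[i]. The sketch below checks this on an array with distinct entries (rather than ones) and also shows that swapping source and destination undoes the move.

```python
import jax.numpy as jnp

a = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

perm = (2, 3, 1, 0)
b = jnp.moveaxis(a, (0, 1), (-1, -2))

# Same values, not just the same shape.
assert bool(jnp.array_equal(b, jnp.transpose(a, perm)))
# Output axis i has the size of input axis perm[i]: (4, 5, 3, 2).
assert b.shape == tuple(a.shape[p] for p in perm)

# Moving the axes back (source and destination swapped) recovers the input.
a_back = jnp.moveaxis(b, (-1, -2), (0, 1))
assert bool(jnp.array_equal(a_back, a))
```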