jax.numpy.swapaxes

jax.numpy.swapaxes(a, axis1, axis2)

Swap two axes of an array.

JAX implementation of numpy.swapaxes(), implemented in terms of jax.lax.transpose().
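To illustrate the relationship to jax.lax.transpose(), the following is a minimal sketch rather than the actual JAX source. The helper name swapaxes_via_transpose is hypothetical. It builds the identity permutation, exchanges the two requested positions, and passes the result to jax.lax.transpose().

```python
import jax.numpy as jnp
from jax import lax


def swapaxes_via_transpose(a, axis1, axis2):
    # Hypothetical helper for illustration only; not the JAX implementation.
    a = jnp.asarray(a)
    # Map negative axis indices into the range [0, a.ndim).
    axis1 = axis1 % a.ndim
    axis2 = axis2 % a.ndim
    # Start from the identity permutation and exchange the two axes.
    perm = list(range(a.ndim))
    perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
    return lax.transpose(a, tuple(perm))


x = jnp.arange(24).reshape(2, 3, 4)
y = swapaxes_via_transpose(x, 0, 2)
same = bool(jnp.array_equal(y, jnp.swapaxes(x, 0, 2)))
```

Here y has the first and last axes of x exchanged, and same records whether the sketch agrees with jnp.swapaxes() on this input.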

Parameters:
  • a (ArrayLike) – input array

  • axis1 (int) – index of first axis

  • axis2 (int) – index of second axis

Returns:

Copy of a with specified axes swapped.

Return type:

Array
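As a brief usage sketch of the parameters above: axis indices may be negative, counting from the last axis, which is convenient for transposing the trailing two dimensions of a batch of matrices. The variable names and shapes below are illustrative assumptions.

```python
import numpy as np
import jax
import jax.numpy as jnp

# A batch of 8 matrices, each of shape (3, 5), given as a NumPy array.
batch = np.ones((8, 3, 5))

# Swap the last two axes; -1 and -2 count from the end.
batch_t = jnp.swapaxes(batch, -1, -2)

# The result is a JAX array even though the input was a NumPy array.
is_jax_array = isinstance(batch_t, jax.Array)
```

The leading batch axis is left in place, so only the per-matrix dimensions are exchanged.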

Notes

Unlike numpy.swapaxes(), jax.numpy.swapaxes() will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.
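The difference between a view and a copy can be seen in a small sketch. With NumPy, writing into the swapped array also changes the original. JAX arrays are immutable, so updates go through the functional .at[...].set() syntax and never affect the input. The function name sum_swapped is a hypothetical example of using swapaxes inside a jitted function.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: swapaxes returns a view, so writes propagate to the original.
x_np = np.zeros((2, 3))
view = np.swapaxes(x_np, 0, 1)
view[0, 0] = 1.0  # also sets x_np[0, 0]

# JAX: the result is a separate array; the input is left unchanged.
x = jnp.zeros((2, 3))
y = jnp.swapaxes(x, 0, 1)
y = y.at[0, 0].set(1.0)  # functional update, returns a new array


@jax.jit
def sum_swapped(a):
    # Hypothetical example: the swap is traced and compiled together
    # with the reduction that follows it.
    return jnp.swapaxes(a, 0, 1).sum(axis=0)


out = sum_swapped(x)
```

Whether a particular copy is actually elided depends on the compiler and the surrounding computation.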

Examples

>>> import jax.numpy as jnp
>>> a = jnp.ones((2, 3, 4, 5))
>>> jnp.swapaxes(a, 1, 3).shape
(2, 5, 4, 3)

Equivalent output via the swapaxes array method:

>>> a.swapaxes(1, 3).shape
(2, 5, 4, 3)

Equivalent output via transpose():

>>> a.transpose(0, 3, 2, 1).shape
(2, 5, 4, 3)