jax.numpy.swapaxes
- jax.numpy.swapaxes(a, axis1, axis2)
Swap two axes of an array.
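JAX implementation of numpy.swapaxes(), implemented in terms of jax.lax.transpose().
- Parameters:
  - a (ArrayLike): input array.
  - axis1 (int): index of the first axis to swap.
  - axis2 (int): index of the second axis to swap.
- Returns:
  Copy of a with the specified axes swapped.
- Return type:
  Array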
Notes
Unlike numpy.swapaxes(), jax.numpy.swapaxes() returns a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this typically has no performance impact in practice.
See also
- jax.numpy.moveaxis(): move a single axis of an array.
- jax.numpy.rollaxis(): older API for moveaxis.
- jax.lax.transpose(): more general axes permutations.
- jax.Array.swapaxes(): same functionality via an array method.
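The difference between swapping two axes and moving a single axis is easy to miss from the descriptions alone. The following is a minimal sketch contrasting jax.numpy.swapaxes() with jax.numpy.moveaxis() on the same input; the array shape and variable names are chosen only for illustration.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4, 5))

# swapaxes exchanges the two named axes; every other axis stays in place.
swapped = jnp.swapaxes(a, 1, 3)

# moveaxis relocates axis 1 to position 3; the axes in between shift left.
moved = jnp.moveaxis(a, 1, 3)

print(swapped.shape)  # (2, 5, 4, 3)
print(moved.shape)    # (2, 4, 5, 3)
```

The two calls agree only when the axes involved are adjacent or identical.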
Examples
>>> import jax.numpy as jnp
>>> a = jnp.ones((2, 3, 4, 5))
>>> jnp.swapaxes(a, 1, 3).shape
(2, 5, 4, 3)
Equivalent output via the swapaxes array method:
>>> a.swapaxes(1, 3).shape
(2, 5, 4, 3)
Equivalent output via transpose():
>>> a.transpose(0, 3, 2, 1).shape
(2, 5, 4, 3)
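Since the function is described as implemented in terms of jax.lax.transpose(), it may help to see how an axis swap can be written as a permutation. The sketch below is an illustration under that assumption, not the library's actual source; swapaxes_sketch is a hypothetical helper name.

```python
import jax.numpy as jnp
from jax import lax


def swapaxes_sketch(a, axis1, axis2):
    # Hypothetical helper for illustration; not the library's source code.
    # Start from the identity permutation (0, 1, ..., ndim - 1).
    perm = list(range(a.ndim))
    # Exchange the two entries; list indexing also handles negative axes.
    perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
    return lax.transpose(a, perm)


a = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
out = swapaxes_sketch(a, 1, 3)

print(out.shape)
print(bool(jnp.array_equal(out, jnp.swapaxes(a, 1, 3))))
```

For the 4-dimensional input here, the permutation built by the helper is [0, 3, 2, 1], which matches the transpose() call in the example above.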
