jax.numpy.moveaxis
jax.numpy.moveaxis(a, source, destination)
Move an array axis to a new position.
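JAX implementation of numpy.moveaxis(), implemented in terms of jax.lax.transpose().
- Parameters:
  - a: input array.
  - source: index or sequence of indices of the axes to move.
  - destination: index or sequence of indices giving the new position of each moved axis.
- Returns:
  Copy of a with axes moved from source to destination.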
Notes
Unlike numpy.moveaxis(), jax.numpy.moveaxis() will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.
See also
- jax.numpy.swapaxes(): swap two axes.
- jax.numpy.rollaxis(): older API for moving an axis.
- jax.numpy.transpose(): general axes permutation.
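The copy-versus-view distinction described in the Notes can be sketched roughly as follows. This is an illustration only: the array shapes and the function name sum_over_moved are made up for the example. JAX arrays are immutable, so the difference from NumPy's view semantics is rarely observable in user code.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: moveaxis returns a view that shares memory with its input.
x_np = np.zeros((2, 3, 4))
y_np = np.moveaxis(x_np, 0, -1)
print(np.shares_memory(x_np, y_np))  # True

# JAX: the result is a new array with the axes rearranged.
x = jnp.arange(24.0).reshape(2, 3, 4)
y = jnp.moveaxis(x, 0, -1)
print(y.shape)  # (3, 4, 2)

# Under jit, the axis move becomes part of the compiled computation,
# where the compiler is free to avoid materializing the intermediate copy.
@jax.jit
def sum_over_moved(x):
    # Hypothetical example function: move axis 0 to the end, then reduce over it.
    return jnp.moveaxis(x, 0, -1).sum(axis=-1)

print(sum_over_moved(x).shape)  # (3, 4)
```

Whether a given copy is actually elided depends on the compiler and the surrounding computation, so treat this as a sketch rather than a guarantee.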
Examples
>>> import jax.numpy as jnp
>>> a = jnp.ones((2, 3, 4, 5))
Move axis 1 to the end of the array:
>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)
Move the last axis to position 1:
>>> jnp.moveaxis(a, -1, 1).shape
(2, 5, 3, 4)
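To make the difference from jax.numpy.swapaxes() concrete, here is a small comparison on the same input shape. moveaxis shifts one axis and keeps the remaining axes in their original relative order, while swapaxes exchanges two axes and leaves the rest in place.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4, 5))

# Axis 1 (size 3) goes to the end; axes of size 2, 4, 5 keep their order.
moved = jnp.moveaxis(a, 1, -1)
print(moved.shape)  # (2, 4, 5, 3)

# Axes 1 and 3 trade places; axes 0 and 2 stay where they are.
swapped = jnp.swapaxes(a, 1, -1)
print(swapped.shape)  # (2, 5, 4, 3)
```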
Move multiple axes:
>>> jnp.moveaxis(a, (0, 1), (-1, -2)).shape
(4, 5, 3, 2)
This can also be accomplished via transpose():
>>> a.transpose(2, 3, 1, 0).shape
(4, 5, 3, 2)
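The equivalence with transpose() can be checked by deriving the permutation explicitly. The helper below, moveaxis_permutation, is hypothetical and written only for this illustration. It follows the approach NumPy's reference implementation is known to use and is not necessarily how JAX computes the permutation internally. It assumes source and destination are sequences of equal length and does not validate out-of-range or duplicate axes.

```python
import jax
import jax.numpy as jnp

def moveaxis_permutation(ndim, source, destination):
    """Hypothetical helper: permutation that moves `source` axes to `destination`."""
    # Normalize negative indices (no bounds checking in this sketch).
    source = [s % ndim for s in source]
    destination = [d % ndim for d in destination]
    # Start with the axes that are not moved, in their original order.
    order = [ax for ax in range(ndim) if ax not in source]
    # Insert each moved axis at its destination, lowest destination first.
    for dest, src in sorted(zip(destination, source)):
        order.insert(dest, src)
    return tuple(order)

a = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
perm = moveaxis_permutation(a.ndim, (0, 1), (-1, -2))
print(perm)  # (2, 3, 1, 0)

moved = jnp.moveaxis(a, (0, 1), (-1, -2))
# All three routes produce the same array.
print(bool(jnp.array_equal(moved, a.transpose(perm))))           # True
print(bool(jnp.array_equal(moved, jax.lax.transpose(a, perm))))  # True
```

Using jnp.arange rather than jnp.ones here means the comparison checks element positions and not just shapes.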
