jax.numpy.rollaxis
jax.numpy.rollaxis(a, axis, start=0)
Roll the specified axis to a given position.
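JAX implementation of numpy.rollaxis().
This function exists for compatibility with NumPy, but in most cases the newer jax.numpy.moveaxis() should be used instead, because the meaning of its arguments is more intuitive.
- Parameters:
  - a (ArrayLike): input array.
  - axis (int): index of the axis to roll forward.
  - start (int): index toward which the axis will be rolled (default = 0). After normalizing negative axes, if start <= axis, the axis is rolled to the start index; if start > axis, the axis is rolled until the position before start.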
- Returns:
Copy of a with rolled axis.
- Return type:
Array
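The rule for start is easiest to see as a translation to jax.numpy.moveaxis(). The sketch below assumes NumPy's rollaxis semantics, which jax.numpy.rollaxis() mirrors, and uses a hypothetical helper, rollaxis_destination, that is not part of JAX or NumPy. It computes the moveaxis destination for a given axis and start, then checks that rule against both libraries for every valid pair on a small 3-D array.

```python
import jax.numpy as jnp
import numpy as np


def rollaxis_destination(ndim: int, axis: int, start: int = 0) -> int:
    """Hypothetical helper (not a JAX API): moveaxis destination for rollaxis.

    Returns `dest` such that jnp.moveaxis(a, axis, dest) matches
    jnp.rollaxis(a, axis, start) for an array `a` with `ndim` dimensions.
    """
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis={axis} out of range for ndim={ndim}")
    if not -ndim <= start <= ndim:
        raise ValueError(f"start={start} out of range for ndim={ndim}")
    # Normalize negative indices first.
    axis %= ndim
    if start < 0:
        start += ndim
    # start <= axis: the axis lands exactly at index `start`.
    # start > axis: the axis lands just before index `start`.
    return start if start <= axis else start - 1


# Check the rule against jnp.rollaxis and np.rollaxis for every valid pair.
a = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)
for axis in range(-a.ndim, a.ndim):
    for start in range(-a.ndim, a.ndim + 1):
        expected = jnp.moveaxis(a, axis, rollaxis_destination(a.ndim, axis, start))
        assert jnp.array_equal(jnp.rollaxis(a, axis, start), expected)
        assert np.array_equal(np.rollaxis(np.asarray(a), axis, start), expected)
```

The asymmetry between the two cases is why moveaxis is usually clearer: its destination argument is always the final position of the axis.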
Notes
Unlike numpy.rollaxis(), jax.numpy.rollaxis() will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.
See also
- jax.numpy.moveaxis(): newer API with clearer semantics than rollaxis; this should be preferred to rollaxis in most cases.
- jax.numpy.swapaxes(): swap two axes.
- jax.numpy.transpose(): general permutation of axes.
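To make the copy-versus-view distinction from the Notes concrete, the sketch below (assuming standard NumPy and JAX installs) writes through a NumPy rollaxis result, which aliases its input, and then performs the equivalent update in JAX. Because JAX arrays are immutable, the update goes through `.at[...].set(...)`, which returns a new array and leaves the input unchanged.

```python
import jax.numpy as jnp
import numpy as np

# NumPy: rollaxis returns a view, so writing to the result changes the input.
x_np = np.zeros((2, 3, 4))
view = np.rollaxis(x_np, 2)  # shape (4, 2, 3), shares memory with x_np
view[0, 0, 0] = 1.0          # view[i, j, k] aliases x_np[j, k, i]
assert np.shares_memory(view, x_np)
assert x_np[0, 0, 0] == 1.0

# JAX: arrays are immutable, so updates are functional and the input is untouched.
x = jnp.zeros((2, 3, 4))
y = jnp.rollaxis(x, 2)           # shape (4, 2, 3), a new array
y = y.at[0, 0, 0].set(1.0)       # returns another new array; x is not modified
assert x[0, 0, 0] == 0.0
assert y[0, 0, 0] == 1.0
```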
Examples
>>> import jax.numpy as jnp
>>> a = jnp.ones((2, 3, 4, 5))
Roll axis 2 to the start of the array:
>>> jnp.rollaxis(a, 2).shape
(4, 2, 3, 5)
Roll axis 1 to the end of the array:
>>> jnp.rollaxis(a, 1, a.ndim).shape
(2, 4, 5, 3)
Equivalent of these two with moveaxis():
>>> jnp.moveaxis(a, 2, 0).shape
(4, 2, 3, 5)
>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)
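The other functions in See also relate to rollaxis as follows: every rollaxis call is one particular permutation that could be passed to jax.numpy.transpose(), while jax.numpy.swapaxes() exchanges two axes without shifting the axes between them. This sketch reuses the shapes from the examples, with jnp.arange in place of jnp.ones so that element values, not just shapes, are compared.

```python
import jax.numpy as jnp

a = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# Rolling keeps the remaining axes in their original relative order,
# so each rollaxis call in the examples is a fixed permutation for transpose.
assert jnp.array_equal(jnp.rollaxis(a, 2), jnp.transpose(a, (2, 0, 1, 3)))
assert jnp.array_equal(jnp.rollaxis(a, 1, a.ndim), jnp.transpose(a, (0, 2, 3, 1)))

# swapaxes exchanges axes 0 and 2 only; axis 1 stays where it was.
assert jnp.swapaxes(a, 0, 2).shape == (4, 3, 2, 5)
assert jnp.rollaxis(a, 2).shape == (4, 2, 3, 5)
```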
