jax.numpy.roll
- jax.numpy.roll(a, shift, axis=None)
Roll the elements of an array along a specified axis.
JAX implementation of numpy.roll().
- Parameters:
a (ArrayLike) – input array.
shift (ArrayLike | Sequence[int]) – the number of positions to shift the specified axis. If an integer, all axes are shifted by the same amount. If a tuple, the shift for each axis is specified individually.
axis (int | Sequence[int] | None) – the axis or axes to roll. If None, the array is flattened, shifted, and then reshaped to its original shape.
- Returns:
A copy of a with elements rolled along the specified axis or axes.
- Return type:
Array
See also
jax.numpy.rollaxis(): roll the specified axis to a given position.
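The distinction between the two functions can be seen from the output shapes. The following is a minimal sketch; the array `x` and the variable names are illustrative and not part of the original documentation.

```python
import jax.numpy as jnp

# Illustrative 3-D input; only its shape matters here.
x = jnp.zeros((2, 3, 4))

# roll shifts elements within an axis, so the shape is unchanged.
roll_shape = jnp.roll(x, 1, axis=2).shape

# rollaxis moves the axis itself (here axis 2 to position 0),
# so the shape is permuted.
rollaxis_shape = jnp.rollaxis(x, 2).shape

print(roll_shape, rollaxis_shape)
```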
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([0, 1, 2, 3, 4, 5])
>>> jnp.roll(a, 2)
Array([4, 5, 0, 1, 2, 3], dtype=int32)
Roll elements along a specific axis:
>>> a = jnp.array([[0, 1, 2, 3],
...                [4, 5, 6, 7],
...                [8, 9, 10, 11]])
>>> jnp.roll(a, 1, axis=0)
Array([[ 8,  9, 10, 11],
       [ 0,  1,  2,  3],
       [ 4,  5,  6,  7]], dtype=int32)
>>> jnp.roll(a, [2, 3], axis=[0, 1])
Array([[ 5,  6,  7,  4],
       [ 9, 10, 11,  8],
       [ 1,  2,  3,  0]], dtype=int32)
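The behavior described for `axis=None` and for per-axis shifts can be checked against equivalent step-by-step computations. This is a hedged sketch rather than a description of how JAX implements the function; the variable names are illustrative, and the negative-shift case is an addition that follows numpy.roll() semantics.

```python
import jax.numpy as jnp

a = jnp.arange(12).reshape(3, 4)

# axis=None: the array is flattened, shifted, then reshaped back.
rolled_flat = jnp.roll(a, 1)
manual_flat = jnp.roll(a.ravel(), 1).reshape(a.shape)

# Per-axis shifts: same result as rolling each axis in turn.
rolled_both = jnp.roll(a, (2, 3), axis=(0, 1))
step_by_step = jnp.roll(jnp.roll(a, 2, axis=0), 3, axis=1)

# A negative shift moves elements toward the start of the axis.
rolled_back = jnp.roll(a, -1, axis=1)

print(bool((rolled_flat == manual_flat).all()))
print(bool((rolled_both == step_by_step).all()))
print(rolled_back)
```

Note that with `axis=None` elements wrap across row boundaries, since the roll is applied to the flattened array.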