jax.numpy.roll

jax.numpy.roll(a, shift, axis=None)

Roll the elements of an array along a specified axis.

JAX implementation of numpy.roll().
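For a 1-D array of length n, rolling by shift moves the element at index j to index (j + shift) mod n, so elements pushed past the end reappear at the start. The sketch below models that rule with explicit index arithmetic and checks it against jnp.roll. roll_1d_reference is a hypothetical helper written for illustration, not part of JAX, and it assumes a 1-D input and a Python integer shift.

```python
import jax.numpy as jnp

def roll_1d_reference(a, shift):
    """Hypothetical reference model of jnp.roll for a 1-D array and an int shift.

    Output slot i takes the input element at (i - shift) mod n, which is the
    same as moving input index j to (j + shift) mod n.
    """
    a = jnp.asarray(a)
    n = a.shape[0]
    src = (jnp.arange(n) - shift) % n  # source index for every output slot
    return a[src]

a = jnp.array([0, 1, 2, 3, 4, 5])
print(roll_1d_reference(a, 2))   # [4 5 0 1 2 3]
print(jnp.roll(a, 2))            # [4 5 0 1 2 3]
print(roll_1d_reference(a, -1))  # [1 2 3 4 5 0]
```

A negative shift rolls toward lower indices, and the modulo keeps every source index in range.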

Parameters:
  • a (ArrayLike) – input array.

  • shift (ArrayLike | Sequence[int]) – the number of positions by which to shift elements along the specified axis or axes. If an integer, every axis selected by axis is shifted by the same amount. If a sequence, the shift for each axis is specified individually, in the same order as axis.

  • axis (int | Sequence[int] | None) – the axis or axes to roll. If None, the array is flattened, shifted, and then reshaped to its original shape.
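A sketch of how shift and axis interact, using a small 2-D array chosen for illustration: an integer shift applies to every listed axis, a sequence shift pairs element-wise with axis, axis=None rolls the flattened array, and shifts wrap modulo the length of the axis they apply to.

```python
import jax.numpy as jnp

b = jnp.arange(12).reshape(3, 4)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]

# Integer shift with several axes: each listed axis is shifted by 1.
same_shift = jnp.roll(b, 1, axis=(0, 1))
step_by_step = jnp.roll(jnp.roll(b, 1, axis=0), 1, axis=1)
print(jnp.array_equal(same_shift, step_by_step))  # True

# Sequence shift: 2 along axis 0 and -1 along axis 1. Because -1 and 3 agree
# modulo 4 (the length of axis 1), this matches a shift of (2, 3).
per_axis = jnp.roll(b, (2, -1), axis=(0, 1))
print(per_axis.tolist())  # [[5, 6, 7, 4], [9, 10, 11, 8], [1, 2, 3, 0]]

# axis=None: flatten, shift, then restore the original shape.
flat = jnp.roll(b, 5)
manual = jnp.roll(b.ravel(), 5).reshape(b.shape)
print(jnp.array_equal(flat, manual))  # True

# Shifts larger than the axis wrap around: 6 along a length-4 axis acts like 2.
print(jnp.array_equal(jnp.roll(b, 6, axis=1), jnp.roll(b, 2, axis=1)))  # True
```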

Returns:

A copy of a with elements rolled along the specified axis or axes.

Return type:

Array
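The result is a new Array with the same shape and dtype as a; only the positions of the elements change, and a itself is left as it was (JAX arrays are immutable). A short check of that contract, with input values chosen for illustration:

```python
import jax.numpy as jnp

a = jnp.array([[1.5, 2.5], [3.5, 4.5]], dtype=jnp.float32)
out = jnp.roll(a, 1, axis=1)  # swap the two columns

print(out.shape, out.dtype)  # (2, 2) float32
print(out.tolist())          # [[2.5, 1.5], [4.5, 3.5]]
print(a.tolist())            # [[1.5, 2.5], [3.5, 4.5]], the input is unchanged
```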

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([0, 1, 2, 3, 4, 5])
>>> jnp.roll(a, 2)
Array([4, 5, 0, 1, 2, 3], dtype=int32)

Roll elements along a specific axis, or along several axes with one shift per axis:

>>> a = jnp.array([[0, 1, 2, 3],
...                [4, 5, 6, 7],
...                [8, 9, 10, 11]])
>>> jnp.roll(a, 1, axis=0)
Array([[ 8,  9, 10, 11],
       [ 0,  1,  2,  3],
       [ 4,  5,  6,  7]], dtype=int32)
>>> jnp.roll(a, [2, 3], axis=[0, 1])
Array([[ 5,  6,  7,  4],
       [ 9, 10, 11,  8],
       [ 1,  2,  3,  0]], dtype=int32)
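Because shift is typed ArrayLike, it need not be a Python constant: it can be a traced value inside a jax.jit-compiled function, while axis stays a static Python value. A minimal sketch, assuming this dynamic-shift behavior of current JAX releases; rotate_rows is a hypothetical name used only for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def rotate_rows(x, k):
    # k is traced, so the same compiled function serves every integer shift;
    # axis=0 is a Python constant and therefore static.
    return jnp.roll(x, k, axis=0)

a = jnp.array([[0, 1, 2, 3],
               [4, 5, 6, 7],
               [8, 9, 10, 11]])
for k in range(3):
    print(k, rotate_rows(a, k)[0].tolist())  # first row after rolling by k
# 0 [0, 1, 2, 3]
# 1 [8, 9, 10, 11]
# 2 [4, 5, 6, 7]
```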