
torch.roll

torch.roll(input, shifts, dims=None) → Tensor

Roll the tensor input along the given dimension(s). Elements that are shifted beyond the last position are re-introduced at the first position. With a negative shift, elements shifted before the first position are re-introduced at the last position. If dims is None, the tensor is flattened before rolling and then restored to its original shape.
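To make the wrap-around and the flatten-then-restore behavior concrete, here is a minimal pure-Python model of the documented semantics for a non-empty rectangular 2-D list. The helper names `roll_1d` and `roll_flat` are hypothetical. This is not how PyTorch implements `torch.roll`. It only mirrors the behavior described above for the case where dims is None.

```python
def roll_1d(seq, shift):
    """Return a new list with elements moved `shift` places to the right, wrapping around."""
    n = len(seq)
    if n == 0:
        return []
    k = shift % n  # Python's % maps negative or oversized shifts into [0, n)
    if k == 0:
        return list(seq)  # a full-cycle shift leaves the order unchanged
    # The last k elements wrap around to the front.
    return seq[-k:] + seq[:-k]


def roll_flat(matrix, shift):
    """Hypothetical model of torch.roll(x, shift) with dims=None on a 2-D list:
    flatten in row-major order, roll, then restore the original shape."""
    rows, cols = len(matrix), len(matrix[0])
    flat = [v for row in matrix for v in row]
    rolled = roll_1d(flat, shift)
    return [rolled[r * cols:(r + 1) * cols] for r in range(rows)]


x = [[1, 2], [3, 4], [5, 6], [7, 8]]
print(roll_flat(x, 1))   # [[8, 1], [2, 3], [4, 5], [6, 7]]
print(roll_flat(x, -1))  # [[2, 3], [4, 5], [6, 7], [8, 1]]
```

The first result matches `torch.roll(x, 1)` in the example below. The second shows a negative shift moving elements toward the start instead.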

Parameters
  • input (Tensor) – the input tensor.

  • shifts (int or tuple of ints) – The number of places by which the elements of the tensor are shifted. If shifts is a tuple, dims must be a tuple of the same size, and each dimension will be rolled by the corresponding value.

  • dims (int or tuple of ints, optional) – Axis or axes along which to roll. Default: None.
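When shifts and dims are tuples, each listed dimension is rolled by its own amount. Rolls along different dimensions commute, so the result should match applying single-dimension rolls one after another. The sketch below checks this on the tensor used in the example. It also checks that a shift larger than the dimension's size wraps around modulo that size. It uses only `torch.roll`, `torch.arange`, `torch.equal`, and `torch.tensor`.

```python
import torch

x = torch.arange(1, 9).view(4, 2)  # same values as the example: [[1, 2], [3, 4], [5, 6], [7, 8]]

# Roll dim 0 by 2 and dim 1 by 1 in a single call...
combined = torch.roll(x, shifts=(2, 1), dims=(0, 1))
# ...which matches two successive single-dimension rolls.
sequential = torch.roll(torch.roll(x, 2, dims=0), 1, dims=1)
assert torch.equal(combined, sequential)

# A shift of 6 along a dimension of size 4 behaves like a shift of 2.
assert torch.equal(torch.roll(x, 6, dims=0), torch.roll(x, 2, dims=0))

print(combined)
```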

Example:

>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
>>> x
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
>>> torch.roll(x, 1)
tensor([[8, 1],
        [2, 3],
        [4, 5],
        [6, 7]])
>>> torch.roll(x, 1, 0)
tensor([[7, 8],
        [1, 2],
        [3, 4],
        [5, 6]])
>>> torch.roll(x, -1, 0)
tensor([[3, 4],
        [5, 6],
        [7, 8],
        [1, 2]])
>>> torch.roll(x, shifts=(2, 1), dims=(0, 1))
tensor([[6, 5],
        [8, 7],
        [2, 1],
        [4, 3]])
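For a single dimension, the results above can also be reproduced with slicing and `torch.cat`. The last k entries along the dimension move to the front, and the remaining entries follow. The helper `roll_via_cat` below is hypothetical. It assumes the rolled dimension is non-empty and only illustrates the semantics. It does not describe how `torch.roll` is implemented internally.

```python
import torch


def roll_via_cat(t, shift, dim):
    """Hypothetical helper: roll `t` along `dim` using narrow() slices and torch.cat."""
    n = t.size(dim)
    k = shift % n  # normalize negative or oversized shifts into [0, n)
    if k == 0:
        return t.clone()
    tail = t.narrow(dim, n - k, k)  # last k entries, which wrap to the front
    head = t.narrow(dim, 0, n - k)  # everything before them
    return torch.cat((tail, head), dim=dim)


x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
for dim in (0, 1):
    for shift in (1, -1, 5):
        assert torch.equal(roll_via_cat(x, shift, dim), torch.roll(x, shift, dim))
```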