
torch.tensor_split

torch.tensor_split(input, indices_or_sections, dim=0) → List of Tensors

Splits a tensor into multiple sub-tensors, all of which are views of input, along dimension dim according to the indices or number of sections specified by indices_or_sections. This function is based on NumPy's numpy.array_split().
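The two claims in the paragraph above, that every piece is a view of input and that the splitting follows numpy.array_split(), can be checked directly. The sketch below writes through one piece and compares section sizes with NumPy; it assumes both torch and numpy are installed, and the variable names are only illustrative.

```python
import numpy as np
import torch

x = torch.arange(10)
parts = torch.tensor_split(x, 3)  # sections of sizes 4, 3, 3

# Each piece is a view of x: writing into it changes x itself.
parts[1][0] = -1  # parts[1] covers x[4:7], so this sets x[4]
print(x)

# Section sizes match numpy.array_split for the same length and count.
torch_sizes = [p.numel() for p in parts]
numpy_sizes = [a.size for a in np.array_split(np.arange(10), 3)]
print(torch_sizes, numpy_sizes)  # [4, 3, 3] [4, 3, 3]
```

If an independent copy is needed instead of a view, calling .clone() on a piece detaches it from the storage of input.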

Parameters
  • input (Tensor) – the tensor to split

  • indices_or_sections (Tensor, int or list or tuple of ints) –

    If indices_or_sections is an integer n or a zero-dimensional long tensor with value n, input is split into n sections along dimension dim. If input.size(dim) is divisible by n, each section has the same size, input.size(dim) / n. If input.size(dim) is not divisible by n, the first int(input.size(dim) % n) sections have size int(input.size(dim) / n) + 1, and the rest have size int(input.size(dim) / n).

    If indices_or_sections is a list or tuple of ints, or a one-dimensional long tensor, then input is split along dimension dim at each of the indices in the list, tuple or tensor. For instance, indices_or_sections=[2, 3] and dim=0 would result in the tensors input[:2], input[2:3], and input[3:].

    If indices_or_sections is a tensor, it must be a zero-dimensional or one-dimensional long tensor on the CPU.

  • dim (int, optional) – dimension along which to split the tensor. Default: 0
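Both forms of indices_or_sections come down to a list of (start, stop) slice boundaries along dim. The plain-Python sketch below reimplements that bookkeeping for a dimension of length size, following the size rule and the slicing rule described in the parameter list. The helpers section_bounds and index_bounds are hypothetical and not part of PyTorch; they illustrate the documented behavior, not how PyTorch computes the split internally.

```python
def section_bounds(size, n):
    """Hypothetical helper: (start, stop) pairs for splitting `size` elements into `n` sections.

    The first size % n sections get size // n + 1 elements; the rest get size // n.
    """
    if n <= 0:
        raise ValueError("number of sections must be positive")
    q, r = divmod(size, n)
    bounds, start = [], 0
    for i in range(n):
        length = q + 1 if i < r else q
        bounds.append((start, start + length))
        start += length
    return bounds


def index_bounds(size, indices):
    """Hypothetical helper: (start, stop) pairs for splitting at each index.

    For example, indices [2, 3] give the slices [:2], [2:3] and [3:].
    """
    edges = [0, *indices, size]
    return list(zip(edges[:-1], edges[1:]))


data = list(range(7))
print([data[a:b] for a, b in section_bounds(len(data), 3)])  # [[0, 1, 2], [3, 4], [5, 6]]
print([data[a:b] for a, b in index_bounds(len(data), (1, 6))])  # [[0], [1, 2, 3, 4, 5], [6]]
```

The same rule also covers n larger than the dimension: section_bounds(2, 3) yields sizes 1, 1 and 0, so the last section is empty.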

Example:

>>> x = torch.arange(8)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]))

>>> x = torch.arange(7)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
>>> torch.tensor_split(x, (1, 6))
(tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6]))

>>> x = torch.arange(14).reshape(2, 7)
>>> x
tensor([[ 0,  1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12, 13]])
>>> torch.tensor_split(x, 3, dim=1)
(tensor([[0, 1, 2],
        [7, 8, 9]]), tensor([[ 3,  4],
        [10, 11]]), tensor([[ 5,  6],
        [12, 13]]))
>>> torch.tensor_split(x, (1, 6), dim=1)
(tensor([[0],
        [7]]), tensor([[ 1,  2,  3,  4,  5],
        [ 8,  9, 10, 11, 12]]), tensor([[ 6],
        [13]]))
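The examples pass indices_or_sections as a Python int or tuple. As the parameter description states, the same splits can be requested with a zero-dimensional or one-dimensional long tensor, which must live on the CPU. A short sketch, assuming a recent PyTorch build:

```python
import torch

x = torch.arange(7)

# A zero-dimensional long tensor acts like the integer n: split into 3 sections.
by_count = torch.tensor_split(x, torch.tensor(3))

# A one-dimensional long tensor acts like a list of split indices.
by_index = torch.tensor_split(x, torch.tensor([1, 6]))

print([p.tolist() for p in by_count])  # [[0, 1, 2], [3, 4], [5, 6]]
print([p.tolist() for p in by_index])  # [[0], [1, 2, 3, 4, 5], [6]]
```

torch.tensor(3) and torch.tensor([1, 6]) default to int64 (long) on the CPU, so both satisfy the stated requirement without an explicit dtype.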