torch.tensor_split
torch.tensor_split(input, indices_or_sections, dim=0) → List of Tensors
Splits a tensor into multiple sub-tensors, all of which are views of input, along dimension dim according to the indices or number of sections specified by indices_or_sections. This function is based on NumPy's numpy.array_split().
Parameters
input (Tensor) – the tensor to split
indices_or_sections (Tensor, int or list or tuple of ints) –
If indices_or_sections is an integer n or a zero-dimensional long tensor with value n, input is split into n sections along dimension dim. If input.size(dim) is divisible by n, each section has the same size, input.size(dim) / n. Otherwise, the first int(input.size(dim) % n) sections have size int(input.size(dim) / n) + 1, and the rest have size int(input.size(dim) / n).
If indices_or_sections is a list or tuple of ints, or a one-dimensional long tensor, then input is split along dimension dim at each of the indices in the list, tuple, or tensor. For instance, indices_or_sections=[2, 3] and dim=0 would result in the tensors input[:2], input[2:3], and input[3:].
If indices_or_sections is a tensor, it must be a zero-dimensional or one-dimensional long tensor on the CPU.
dim (int, optional) – dimension along which to split the tensor. Default: 0
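The sizing rule for an integer n can be checked with plain arithmetic. The sketch below uses a hypothetical helper, `section_sizes`, written in pure Python so it runs without PyTorch; it only restates the rule above and is not part of the torch API.

```python
from typing import List


def section_sizes(length: int, n: int) -> List[int]:
    """Sizes of the n sections produced along a dimension of the given length.

    Hypothetical helper restating the documented rule: the first
    length % n sections get one extra element.
    """
    if n <= 0:
        raise ValueError("number of sections must be positive")
    base, extra = divmod(length, n)
    # The first `extra` sections have size base + 1, the rest have size base.
    return [base + 1] * extra + [base] * (n - extra)


print(section_sizes(8, 3))  # [3, 3, 2]
print(section_sizes(7, 3))  # [3, 2, 2]
print(section_sizes(6, 3))  # [2, 2, 2]
```

The first two calls match the section lengths in the torch.arange(8) and torch.arange(7) examples that follow.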
Example:
>>> x = torch.arange(8)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]))

>>> x = torch.arange(7)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
>>> torch.tensor_split(x, (1, 6))
(tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6]))

>>> x = torch.arange(14).reshape(2, 7)
>>> x
tensor([[ 0,  1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12, 13]])
>>> torch.tensor_split(x, 3, dim=1)
(tensor([[0, 1, 2],
        [7, 8, 9]]),
 tensor([[ 3,  4],
        [10, 11]]),
 tensor([[ 5,  6],
        [12, 13]]))
>>> torch.tensor_split(x, (1, 6), dim=1)
(tensor([[0],
        [7]]),
 tensor([[ 1,  2,  3,  4,  5],
        [ 8,  9, 10, 11, 12]]),
 tensor([[ 6],
        [13]]))
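The one-dimensional examples above can be reproduced on plain Python sequences. The sketch below uses two hypothetical helpers, `split_at_indices` and `split_into_sections`, which are not part of PyTorch. It assumes a 1-D sequence and shows that splitting into n sections amounts to splitting at the cumulative section sizes.

```python
from typing import List, Sequence


def split_at_indices(seq: Sequence[int], indices: Sequence[int]) -> List[list]:
    """Split seq at each index, like passing a list/tuple of ints.

    Hypothetical pure-Python helper: [2, 3] yields seq[:2], seq[2:3], seq[3:].
    """
    bounds = [0, *indices, len(seq)]
    # Consecutive boundary pairs become the slices seq[start:stop].
    return [list(seq[start:stop]) for start, stop in zip(bounds, bounds[1:])]


def split_into_sections(seq: Sequence[int], n: int) -> List[list]:
    """Split seq into n sections by converting section sizes into split indices."""
    if n <= 0:
        raise ValueError("number of sections must be positive")
    base, extra = divmod(len(seq), n)
    indices = []
    pos = 0
    # Only n - 1 boundaries are needed; the last section runs to the end.
    for i in range(n - 1):
        pos += base + 1 if i < extra else base
        indices.append(pos)
    return split_at_indices(seq, indices)


print(split_into_sections(range(8), 3))    # [[0, 1, 2], [3, 4, 5], [6, 7]]
print(split_into_sections(range(7), 3))    # [[0, 1, 2], [3, 4], [5, 6]]
print(split_at_indices(range(7), (1, 6)))  # [[0], [1, 2, 3, 4, 5], [6]]
```

Unlike torch.tensor_split, these helpers copy elements into new lists rather than returning views, and they handle only a single dimension.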