torch.split

torch.split(tensor, split_size_or_sections, dim=0)

Splits the tensor into chunks. Each chunk is a view of the original tensor.
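Because each chunk is a view, writing into a chunk also changes the original tensor. The short sketch below illustrates this on a plain integer tensor that does not require gradients; the variable names are purely illustrative. Calling `.clone()` on a chunk is one way to get an independent copy.

```python
import torch

# A 1-D tensor holding 0..5
base = torch.arange(6)

# Split into three chunks of size 2 along dim 0
first, second, third = torch.split(base, 2)

# Writing into a chunk writes into the original tensor, because the chunk is a view
first[0] = 100
print(base.tolist())  # [100, 1, 2, 3, 4, 5]

# The first chunk starts at the same memory address as the original tensor
print(first.data_ptr() == base.data_ptr())  # True

# .clone() gives an independent copy; writing to it leaves base unchanged
second_copy = second.clone()
second_copy[0] = -1
print(base.tolist())  # [100, 1, 2, 3, 4, 5]
```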

If split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible). The last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size_or_sections.
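A minimal sketch of the integer case: splitting 10 elements into chunks of size 3 gives three full chunks followed by a chunk holding the remainder, so the number of chunks is ceil(10 / 3) = 4. A split size at least as large as the dimension yields a single chunk.

```python
import math
import torch

x = torch.arange(10)

# split_size=3 does not divide 10 evenly: three full chunks, then a remainder of 1
chunks = torch.split(x, 3)
sizes = [c.size(0) for c in chunks]
print(sizes)  # [3, 3, 3, 1]

# The number of chunks is ceil(10 / 3) = 4
print(len(chunks) == math.ceil(x.size(0) / 3))  # True

# A split size at least as large as the dimension yields a single chunk
print(len(torch.split(x, 20)))  # 1
```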

If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks whose sizes along dim are given by split_size_or_sections.
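The following sketch splits the columns of a 3×4 tensor into widths 1 and 3. The list entries have to add up to the size of the split dimension; as far as we know, current PyTorch versions raise a `RuntimeError` otherwise, which the second half of the sketch relies on.

```python
import torch

x = torch.arange(12).reshape(3, 4)

# Sections [1, 3] along dim=1: one column on the left, three on the right
left, right = torch.split(x, [1, 3], dim=1)
print(tuple(left.shape), tuple(right.shape))  # (3, 1) (3, 3)

# The section sizes must add up to x.size(1) == 4; [1, 2] sums to 3, so this fails
raised = False
try:
    torch.split(x, [1, 2], dim=1)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
print(raised)  # True
```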

Parameters:
  • tensor (Tensor) – tensor to split.

  • split_size_or_sections (int or list(int)) – size of a single chunk or list of sizes for each chunk.

  • dim (int) – dimension along which to split the tensor.
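As a rough illustration of the dim parameter: with the default dim=0 a 2-D tensor is split by rows, while dim=-1 counts from the end and, for a 2-D tensor, selects the same dimension as dim=1. The tensor shape used here is an arbitrary choice for the sketch.

```python
import torch

x = torch.arange(12).reshape(3, 4)

# Default dim=0: chunks of 2 rows -> shapes (2, 4) and (1, 4)
print([tuple(c.shape) for c in torch.split(x, 2)])  # [(2, 4), (1, 4)]

# dim=-1 refers to the last dimension (here the same as dim=1)
by_last = torch.split(x, 2, dim=-1)
by_one = torch.split(x, 2, dim=1)
print([tuple(c.shape) for c in by_last])  # [(3, 2), (3, 2)]
print(all(torch.equal(p, q) for p, q in zip(by_last, by_one)))  # True
```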

Return type:

tuple[Tensor, …]
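Since the result is a tuple of tensors, it can be unpacked directly. A sketch of one common pattern, assuming the chunks are recombined along the same dimension: `torch.cat` over the chunks reproduces the original tensor.

```python
import torch

x = torch.arange(10).reshape(5, 2)

# torch.split returns a tuple, so it can be unpacked directly
head, tail = torch.split(x, [1, 4])
print(isinstance(torch.split(x, [1, 4]), tuple))  # True
print(tuple(head.shape), tuple(tail.shape))  # (1, 2) (4, 2)

# Concatenating the chunks along the same dim reconstructs the input
print(torch.equal(torch.cat((head, tail), dim=0), x))  # True
```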

Example:

>>> a = torch.arange(10).reshape(5, 2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))
>>> torch.split(a, [1, 4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))