
torch.flatten

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

Flattens input by reshaping it into a one-dimensional tensor. If start_dim or end_dim are passed, only the dimensions from start_dim through end_dim (inclusive) are flattened. The order of elements in input is unchanged.
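The shape rule implied above can be written down directly: the dimensions from start_dim through end_dim are replaced by a single dimension whose size is their product, and all other dimensions are kept. The sketch below checks that rule against torch.flatten. The helper expected_flatten_shape is hypothetical (it is not part of PyTorch) and assumes start_dim and end_dim are already valid for the input.

```python
import math
import torch

def expected_flatten_shape(shape, start_dim=0, end_dim=-1):
    """Hypothetical helper: the shape torch.flatten is expected to produce.

    Assumes start_dim and end_dim are valid dimension indices for `shape`.
    """
    ndim = len(shape)
    if ndim == 0:
        return (1,)  # a zero-dimensional tensor flattens to shape (1,)
    # Wrap negative indices the way Python indexing does.
    s = start_dim % ndim
    e = end_dim % ndim
    merged = math.prod(shape[s:e + 1])  # product of the flattened sizes
    return tuple(shape[:s]) + (merged,) + tuple(shape[e + 1:])

x = torch.zeros(2, 3, 4, 5)
# Expected shapes: (120,), (2, 60), (2, 12, 5)
for s, e in [(0, -1), (1, -1), (1, 2)]:
    out = torch.flatten(x, start_dim=s, end_dim=e)
    assert tuple(out.shape) == expected_flatten_shape(x.shape, s, e)
    print(s, e, tuple(out.shape))
```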

Unlike NumPy’s flatten, which always copies input’s data, this function may return the original object, a view, or a copy. If no dimensions are flattened, the original object input is returned. Otherwise, if input can be viewed as the flattened shape, that view is returned. Only if input cannot be viewed as the flattened shape is its data copied. See torch.Tensor.view() for details on when a view will be returned.
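The three cases above can be observed by comparing data_ptr() values, as in this small sketch. It assumes a contiguous input for the view case and uses a transpose to produce a non-contiguous input for the copy case; the variable names are illustrative only.

```python
import torch

a = torch.arange(6).reshape(2, 3)          # contiguous, shape (2, 3)

# Case 1: start_dim == end_dim, so no dimensions are flattened and
# the input itself comes back (same storage, same shape).
same = torch.flatten(a, start_dim=1, end_dim=1)
print(same.shape)                          # torch.Size([2, 3])
print(same.data_ptr() == a.data_ptr())     # True

# Case 2: a contiguous input can be viewed as the flattened shape.
flat_a = torch.flatten(a)
print(flat_a.data_ptr() == a.data_ptr())   # True: shares storage with a
flat_a[0] = 100                            # writing through the view...
print(a[0, 0].item())                      # ...changes the input: 100

# Case 3: a transpose is not contiguous, so no flat view exists
# and torch.flatten has to copy the data.
b = a.t()                                  # shape (3, 2), shares a's storage
print(b.is_contiguous())                   # False
flat_b = torch.flatten(b)
print(flat_b.data_ptr() == b.data_ptr())   # False: fresh storage
print(flat_b.tolist())                     # [100, 3, 1, 4, 2, 5]
```

Because case 2 returns a view, in-place writes to the result are visible in the input; code that must not alias the input can call .clone() on the result.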

Note

Flattening a zero-dimensional tensor will return a one-dimensional view.
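A minimal check of the note above: a scalar (zero-dimensional) tensor becomes a one-element, one-dimensional tensor that shares storage with the original.

```python
import torch

s = torch.tensor(7.0)                 # zero-dimensional: shape torch.Size([])
f = torch.flatten(s)
print(s.dim(), f.dim())               # 0 1
print(f.shape)                        # torch.Size([1])
print(f.data_ptr() == s.data_ptr())   # True: a view, not a copy
```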

Parameters
  • input (Tensor) – the input tensor.

  • start_dim (int) – the first dim to flatten. Default: 0.

  • end_dim (int) – the last dim to flatten (inclusive). Default: -1.
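Both parameters accept negative indices, counted from the last dimension, and start_dim must resolve to a position no later than end_dim. The sketch below illustrates this on a 4-D tensor. In the releases we are aware of, an out-of-order pair raises a RuntimeError; treat the exact exception type as an assumption rather than a documented guarantee.

```python
import torch

x = torch.zeros(8, 3, 32, 32)

# Negative indices count from the end; end_dim is inclusive,
# so the last two dims (32 and 32) merge into one of size 1024.
y = torch.flatten(x, start_dim=-2, end_dim=-1)
print(y.shape)   # torch.Size([8, 3, 1024])

# Mixing signs is fine as long as start_dim resolves to a position <= end_dim:
# dims 1 and 2 (sizes 3 and 32) merge into one of size 96.
z = torch.flatten(x, start_dim=1, end_dim=-2)
print(z.shape)   # torch.Size([8, 96, 32])

# start_dim resolving to a later position than end_dim is rejected.
# (Assumption: surfaced as RuntimeError, as in current releases.)
raised = False
try:
    torch.flatten(x, start_dim=2, end_dim=1)
except RuntimeError:
    raised = True
print(raised)    # True
```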

Example:

>>> t = torch.tensor([[[1, 2],
...                    [3, 4]],
...                   [[5, 6],
...                    [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])