
torch.unflatten

torch.unflatten(input, dim, sizes) → Tensor

Expands a dimension of the input tensor over multiple dimensions.
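As an illustration of what "expanding a dimension" means, the sketch below compares torch.unflatten against an explicit Tensor.view whose target shape keeps the leading and trailing dimensions and splices sizes in at position dim. The names x, dim, sizes and new_shape are illustrative only, and the equivalence is shown for this one input rather than claimed as how PyTorch implements the function.

```python
import torch

x = torch.randn(3, 4, 1)
dim, sizes = 1, (2, 2)

# unflatten replaces the single dimension x.shape[dim] (size 4) with sizes (2, 2)
y = torch.unflatten(x, dim, sizes)

# The same target shape built by hand: leading dims + sizes + trailing dims
new_shape = tuple(x.shape[:dim]) + tuple(sizes) + tuple(x.shape[dim + 1:])
z = x.view(new_shape)

print(y.shape)            # torch.Size([3, 2, 2, 1])
print(torch.equal(y, z))  # True
```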

See also

torch.flatten() is the inverse of this function. It coalesces several dimensions into one.
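A minimal round-trip sketch of the inverse relationship, assuming a small integer tensor so exact equality can be checked. It goes in both directions: unflatten then flatten over the new dimensions, and flatten then unflatten with the original sizes. The variable names are hypothetical.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

# Split dim 2 (size 4) into (2, 2), then merge dims 2..3 back into one.
y = torch.unflatten(x, 2, (2, 2))                # shape (2, 3, 2, 2)
back = torch.flatten(y, start_dim=2, end_dim=3)  # shape (2, 3, 4)

# Opposite direction: merge dims 1..2 into one, then split it again.
f = torch.flatten(x, start_dim=1)                # shape (2, 12)
restored = torch.unflatten(f, 1, tuple(x.shape[1:]))  # sizes (3, 4)

print(torch.equal(back, x), torch.equal(restored, x))  # True True
```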

Parameters:
  • input (Tensor) – the input tensor.

  • dim (int) – Dimension to be unflattened, specified as an index into input.shape.

  • sizes (Tuple[int]) – New shape of the unflattened dimension. One of its elements can be -1, in which case the corresponding output dimension is inferred; the product of the remaining elements must then divide input.shape[dim]. Otherwise, the product of sizes must equal input.shape[dim].
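The rules for dim and sizes can be sketched as pure shape arithmetic. The helper unflattened_shape below is hypothetical (it is not part of PyTorch). It wraps a negative dim, infers a single -1 entry, and checks the product rule. It raises ValueError/IndexError by its own choice; PyTorch reports these conditions with its own error types and messages.

```python
from math import prod
from typing import Sequence, Tuple


def unflattened_shape(shape: Sequence[int], dim: int, sizes: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: the output shape of unflattening `shape` at `dim` into `sizes`."""
    ndim = len(shape)
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for a {ndim}-d shape")
    dim %= ndim  # wrap a negative dim, e.g. -2 -> ndim - 2
    if len(sizes) == 0:
        raise ValueError("sizes must be non-empty")

    target = shape[dim]
    if sum(1 for s in sizes if s == -1) > 1:
        raise ValueError("at most one element of sizes may be -1")
    known = prod(s for s in sizes if s != -1)  # product of the explicitly given sizes

    if -1 in sizes:
        # Infer the missing size: the known sizes must divide shape[dim] exactly.
        if known == 0 or target % known != 0:
            raise ValueError(f"cannot infer -1 for shape[dim] = {target} from known product {known}")
        sizes = [target // known if s == -1 else s for s in sizes]
    elif known != target:
        raise ValueError(f"product of sizes ({known}) must equal shape[dim] ({target})")

    # Splice the new sizes in place of the unflattened dimension.
    return tuple(shape[:dim]) + tuple(sizes) + tuple(shape[dim + 1:])


print(unflattened_shape((3, 4, 1), 1, (2, 2)))             # (3, 2, 2, 1)
print(unflattened_shape((3, 4, 1), 1, (-1, 2)))            # (3, 2, 2, 1)
print(unflattened_shape((5, 12, 3), -2, (2, 2, 3, 1, 1)))  # (5, 2, 2, 3, 1, 1, 3)
```

The three calls mirror the shapes in the examples below. A call such as unflattened_shape((3, 4, 1), 1, (3, 2)) fails the product rule, because 3 * 2 = 6 does not equal 4.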

Returns:

A View of input with the specified dimension unflattened.
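Because the result is a view, it shares storage with input. The sketch below assumes a small contiguous zero tensor and shows that a write through the unflattened view is visible in the original. Element [0, 1, 2] of the (2, 2, 3) view corresponds to element [0, 1 * 3 + 2] = [0, 5] of the (2, 6) input.

```python
import torch

x = torch.zeros(2, 6)
y = torch.unflatten(x, 1, (2, 3))  # shape (2, 2, 3), a view of x

# Writing through the view modifies the original tensor as well.
y[0, 1, 2] = 7.0
print(x[0, 5])                       # tensor(7.)
print(y.data_ptr() == x.data_ptr())  # True: same underlying storage
```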

Examples::

    >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape
    torch.Size([3, 2, 2, 1])
    >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape
    torch.Size([3, 2, 2, 1])
    >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape
    torch.Size([5, 2, 2, 3, 1, 1, 3])