
Flatten

class torch.nn.modules.flatten.Flatten(start_dim=1, end_dim=-1)

Flattens a contiguous range of dimensions of the input tensor into a single dimension.

Intended for use with Sequential; see torch.flatten() for details.
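
As a minimal sketch of typical use inside nn.Sequential, the snippet below places Flatten between a convolution and a linear layer. The layer sizes are arbitrary assumptions chosen only for illustration. It also checks that nn.Flatten() gives the same result as torch.flatten called with the same start_dim and end_dim.

```python
import torch
from torch import nn

# Illustrative layer sizes (assumptions, not part of the reference above).
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),  # (N, 1, 5, 5) -> (N, 4, 5, 5)
    nn.ReLU(),
    nn.Flatten(),                               # (N, 4, 5, 5) -> (N, 100)
    nn.Linear(4 * 5 * 5, 10),                   # (N, 100) -> (N, 10)
)

x = torch.randn(8, 1, 5, 5)
out = model(x)
print(out.shape)  # torch.Size([8, 10])

# nn.Flatten() uses start_dim=1, end_dim=-1, so the batch dimension is kept.
y = torch.randn(8, 4, 5, 5)
assert torch.equal(nn.Flatten()(y), torch.flatten(y, start_dim=1, end_dim=-1))
```

Note that torch.flatten itself defaults to start_dim=0, so calling it without arguments would also merge the batch dimension. The module's default of start_dim=1 is what makes it convenient in Sequential models.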

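Shape:
  • Input: $(*, S_{\text{start}}, \ldots, S_i, \ldots, S_{\text{end}}, *)$, where $S_i$ is the size at dimension $i$ and $*$ means any number of dimensions, including none.

  • Output: $(*, \prod_{i=\text{start}}^{\text{end}} S_i, *)$.
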
Parameters
  • start_dim (int) – first dim to flatten (default = 1).

  • end_dim (int) – last dim to flatten (default = -1).
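
The shape rule can be checked with a small helper. flattened_shape below is a hypothetical function written only for this illustration (it is not part of PyTorch). It assumes at least one dimension and a valid range with start_dim <= end_dim after negative indices are resolved, and it compares its prediction against nn.Flatten for a few argument pairs.

```python
import math

import torch
from torch import nn


def flattened_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: expected output shape of nn.Flatten(start_dim, end_dim).

    Sizes before start_dim and after end_dim are kept; sizes in the inclusive
    range [start_dim, end_dim] are replaced by their product.
    Assumes len(shape) >= 1 and valid, ordered dims.
    """
    ndim = len(shape)
    # Negative dims count from the end, as with other PyTorch dim arguments.
    start = start_dim % ndim
    end = end_dim % ndim
    merged = math.prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


x = torch.randn(32, 1, 5, 5)
for start_dim, end_dim in [(1, -1), (0, 2), (1, 2), (-2, -1)]:
    expected = flattened_shape(x.shape, start_dim, end_dim)
    actual = tuple(nn.Flatten(start_dim, end_dim)(x).shape)
    print((start_dim, end_dim), actual)
    assert actual == expected
```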

Examples:
>>> import torch
>>> from torch import nn
>>> input = torch.randn(32, 1, 5, 5)
>>> # With default parameters
>>> m = nn.Flatten()
>>> output = m(input)
>>> output.size()
torch.Size([32, 25])
>>> # With non-default parameters
>>> m = nn.Flatten(0, 2)
>>> output = m(input)
>>> output.size()
torch.Size([160, 5])
extra_repr()

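Returns the extra representation of the module, the string "start_dim=<start_dim>, end_dim=<end_dim>" that appears inside the parentheses when the module is printed.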

Return type

str
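
A short sketch of where this string surfaces: printing the module (or a container holding it) calls repr, which wraps the extra_repr() text in the class name. The argument values below are arbitrary choices for the illustration.

```python
from torch import nn

m = nn.Flatten(0, 2)

# extra_repr() supplies the text inside the parentheses of the module's repr.
print(m.extra_repr())  # start_dim=0, end_dim=2
print(m)               # Flatten(start_dim=0, end_dim=2)

# The same text appears when Flatten is nested inside a container.
print(nn.Sequential(nn.Flatten(), nn.Linear(25, 10)))
```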

forward(input)

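Runs the forward pass: flattens input over the dimensions start_dim through end_dim (inclusive), equivalent to input.flatten(self.start_dim, self.end_dim).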

Return type

Tensor
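
The sketch below compares the module's output with Tensor.flatten using the module's stored start_dim and end_dim, and with an equivalent reshape for the default arguments. The input shape is an arbitrary assumption. It also shows why the module instance should normally be called instead of forward() directly: registered hooks run only through the module call.

```python
import torch
from torch import nn

x = torch.randn(4, 3, 2, 2)
m = nn.Flatten()  # start_dim=1, end_dim=-1

out = m(x)  # calling the module runs forward() plus any registered hooks
print(out.shape)  # torch.Size([4, 12])

# Same result as Tensor.flatten with the module's stored arguments.
assert torch.equal(out, x.flatten(m.start_dim, m.end_dim))
# With the defaults: keep the batch dimension and merge everything else.
assert torch.equal(out, x.reshape(x.size(0), -1))

# Forward hooks fire on m(x) but not on a direct m.forward(x) call.
calls = []
handle = m.register_forward_hook(lambda module, args, output: calls.append(output.shape))
m(x)          # hook fires
m.forward(x)  # hook does not fire
handle.remove()
print(len(calls))  # 1
```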