torch.trace

torch.trace(input) → Tensor

Returns the sum of the elements of the diagonal of the input 2-D matrix.

Example:

>>> x = torch.arange(1., 10.).view(3, 3)
>>> x
tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.]])
>>> torch.trace(x)
tensor(15.)
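
To make the arithmetic behind the result explicit, here is a minimal pure-Python sketch of the same computation. The helper `trace_2d` is hypothetical and not part of PyTorch; it assumes the matrix is given as a list of equal-length rows and takes the diagonal to be the entries whose row index equals their column index.

```python
def trace_2d(matrix):
    """Sum the entries matrix[i][i] of a 2-D matrix given as a list of rows.

    Hypothetical illustration only; not PyTorch's implementation.
    """
    if not matrix:
        return 0
    # The diagonal ends at whichever of the two dimensions is shorter.
    n = min(len(matrix), len(matrix[0]))
    return sum(matrix[i][i] for i in range(n))


# Same values as the 3x3 example above.
x = [[1., 2., 3.],
     [4., 5., 6.],
     [7., 8., 9.]]
result = trace_2d(x)  # adds x[0][0] + x[1][1] + x[2][2] = 1 + 5 + 9
print(result)
```

For the matrix in the example, the diagonal entries are 1, 5, and 9, which sum to 15, matching the value returned by torch.trace(x).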