
torch.argwhere

torch.argwhere(input) → Tensor

Returns a tensor containing the indices of all non-zero elements of input. Each row in the result contains the indices of one non-zero element in input. The result is sorted lexicographically, with the last index changing the fastest (C-style).
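To make the ordering concrete, here is a minimal pure-Python sketch of the same semantics on nested lists. It is an illustrative reference only, not PyTorch's implementation. The helpers shape_of, value_at and argwhere_reference are hypothetical names, and the sketch assumes the nested list is rectangular.

```python
from itertools import product
from typing import Any, List, Sequence, Tuple


def shape_of(data: Any) -> Tuple[int, ...]:
    # Hypothetical helper: recover the shape of a rectangular nested list
    # by following the first element at each nesting level.
    shape = []
    while isinstance(data, list):
        shape.append(len(data))
        data = data[0] if data else None
    return tuple(shape)


def value_at(data: Any, index: Sequence[int]) -> Any:
    # Hypothetical helper: look up data[i0][i1]...[ik] for a full index.
    for i in index:
        data = data[i]
    return data


def argwhere_reference(data: Any) -> List[List[int]]:
    # Hypothetical reference, not PyTorch's code. itertools.product yields
    # index tuples in lexicographic order with the last index changing
    # fastest, i.e. C-style (row-major) order, so filtering it keeps that order.
    ranges = [range(n) for n in shape_of(data)]
    return [list(idx) for idx in product(*ranges) if value_at(data, idx) != 0]


print(argwhere_reference([1, 0, 1]))
print(argwhere_reference([[1, 0, 1], [0, 1, 1]]))
```

On [[1, 0, 1], [0, 1, 1]] it returns [[0, 0], [0, 2], [1, 1], [1, 2]], the same rows torch.argwhere produces for that input.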

If input has $n$ dimensions, then the resulting indices tensor is of size $(z \times n)$, where $z$ is the total number of non-zero elements in the input tensor.
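The shape rule can be checked directly. The sketch below assumes a small 3-D integer tensor chosen purely for illustration; it confirms that the result has z rows and n columns, that each row is a full index selecting a non-zero value, and that the result matches torch.nonzero with its default as_tuple=False.

```python
import torch

# Illustration data (an assumption): a 3-D tensor, so n = 3.
t = torch.tensor([[[0, 2], [3, 0]],
                  [[0, 0], [5, 6]]])

idx = torch.argwhere(t)

z = int(torch.count_nonzero(t))  # total number of non-zero elements
n = t.dim()                      # number of dimensions of the input
print(idx.shape == (z, n))       # True: one row per non-zero, one column per dim

# Each row of idx indexes one non-zero element; gathering all rows at once
# returns the non-zero values in C-style order.
values = t[tuple(idx.T)]
print(values)

# Same result as torch.nonzero in its default (as_tuple=False) form.
print(torch.equal(idx, torch.nonzero(t)))
```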

Note

This function is similar to NumPy’s argwhere.

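When input is on CUDA, this function causes host-device synchronization, because the size of the result depends on how many elements are non-zero, and that count must be copied back to the host before the output can be allocated.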
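The similarity to NumPy can be seen by running both functions on the same array. This is a small sketch that assumes NumPy is installed alongside PyTorch; the array is the one used in the example on this page, shared with PyTorch through torch.from_numpy.

```python
import numpy as np
import torch

a = np.array([[1, 0, 1],
              [0, 1, 1]])

np_idx = np.argwhere(a)                       # NumPy result, shape (z, n)
pt_idx = torch.argwhere(torch.from_numpy(a))  # PyTorch result, shape (z, n)

# Both list the same index rows in the same C-style order.
print(np.array_equal(np_idx, pt_idx.numpy()))
```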

Parameters

input (Tensor) – the input tensor.

Example:

>>> t = torch.tensor([1, 0, 1])
>>> torch.argwhere(t)
tensor([[0],
        [2]])
>>> t = torch.tensor([[1, 0, 1], [0, 1, 1]])
>>> torch.argwhere(t)
tensor([[0, 0],
        [0, 2],
        [1, 1],
        [1, 2]])