torch.argwhere
torch.argwhere(input) → Tensor
Returns a tensor containing the indices of all non-zero elements of input. Each row in the result contains the indices of a non-zero element in input. The result is sorted lexicographically, with the last index changing the fastest (C-style).
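To make the C-style ordering concrete, here is a plain-Python sketch of the semantics. The helpers `argwhere_reference`, `_shape` and `_get` are hypothetical names used only for illustration; they are not part of PyTorch and say nothing about how PyTorch implements the operation. The sketch assumes a rectangular nested list and relies on `itertools.product`, which varies its last argument fastest, so the index tuples come out in lexicographic (row-major) order.

```python
from itertools import product
from typing import Any, List, Tuple


def _shape(data: Any) -> Tuple[int, ...]:
    # Hypothetical helper: infer the shape of a rectangular nested list
    # by walking down the first element at each level.
    shape = []
    while isinstance(data, (list, tuple)):
        shape.append(len(data))
        if not data:
            break
        data = data[0]
    return tuple(shape)


def _get(data: Any, index: Tuple[int, ...]) -> Any:
    # Hypothetical helper: follow one index per level to reach a scalar.
    for i in index:
        data = data[i]
    return data


def argwhere_reference(data: Any) -> List[List[int]]:
    # Hypothetical reference: enumerate every index tuple in C-style order
    # (last index changing fastest) and keep those whose element is non-zero.
    shape = _shape(data)
    return [
        list(index)
        for index in product(*(range(n) for n in shape))
        if _get(data, index) != 0
    ]


print(argwhere_reference([[1, 0, 1], [0, 1, 1]]))
# [[0, 0], [0, 2], [1, 1], [1, 2]]
```

Each returned row holds one coordinate per dimension, and rows appear in the same order as in the example output further down this page.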
If input has $n$ dimensions, then the resulting indices tensor out is of size $(z \times n)$, where $z$ is the total number of non-zero elements in the input tensor.

Note
This function is similar to NumPy’s argwhere.
When input is on CUDA, this function causes host-device synchronization, because the number of rows in the result depends on the values stored in input.

Parameters
input (Tensor) – the input tensor.
Example:
>>> t = torch.tensor([1, 0, 1])
>>> torch.argwhere(t)
tensor([[0],
        [2]])
>>> t = torch.tensor([[1, 0, 1], [0, 1, 1]])
>>> torch.argwhere(t)
tensor([[0, 0],
        [0, 2],
        [1, 1],
        [1, 2]])
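The shape rule and the NumPy comparison can be checked directly. The snippet below is a sketch that assumes NumPy is installed alongside PyTorch. It checks that the result has one row per non-zero element (counted with `torch.count_nonzero`) and one column per dimension. It also checks that the rows match `torch.nonzero` with its default `as_tuple=False` and match `numpy.argwhere` on the same data. All tensors here live on the CPU.

```python
import numpy as np
import torch

t = torch.tensor([[1, 0, 1], [0, 1, 1]])
idx = torch.argwhere(t)

# Size is (z x n): z non-zero elements, n = t.dim() index columns.
z = int(torch.count_nonzero(t))
print(idx.shape)  # torch.Size([4, 2])
assert idx.shape == (z, t.dim())

# Same rows, in the same order, as torch.nonzero (default as_tuple=False).
assert torch.equal(idx, torch.nonzero(t))

# Same indices as NumPy's argwhere on the equivalent array.
assert idx.tolist() == np.argwhere(t.numpy()).tolist()

# An all-zero input yields an empty (0 x n) tensor rather than an error.
empty = torch.argwhere(torch.zeros(2, 3))
print(empty.shape)  # torch.Size([0, 2])
```

Comparing through `tolist()` sidesteps any dtype difference between the PyTorch result (int64) and the NumPy result (platform `intp`).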