torch.where
- torch.where(condition, input, other, *, out=None) → Tensor
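Return a tensor of elements selected from either input or other, depending on condition. The operation is defined as:

$$
\text{out}_i = \begin{cases} \text{input}_i & \text{if } \text{condition}_i \\ \text{other}_i & \text{otherwise} \end{cases}
$$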
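As a rough illustration of the elementwise rule, the sketch below compares torch.where against a slow, explicit Python loop. The helper `where_reference` is hypothetical and exists only for this comparison. It assumes all three tensors already share one shape, so unlike torch.where it does no broadcasting.

```python
import torch

# Hypothetical reference helper: applies "input if condition else other"
# one element at a time. Assumes identical shapes (no broadcasting).
def where_reference(condition, inp, other):
    flat = [
        i if c else o
        for c, i, o in zip(condition.flatten().tolist(),
                           inp.flatten().tolist(),
                           other.flatten().tolist())
    ]
    return torch.tensor(flat, dtype=inp.dtype).reshape(inp.shape)

condition = torch.tensor([[True, False], [False, True]])
inp = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
other = torch.tensor([[-1.0, -2.0], [-3.0, -4.0]])

fast = torch.where(condition, inp, other)      # vectorized selection
slow = where_reference(condition, inp, other)  # per-element loop
print(fast)
print(torch.equal(fast, slow))  # True
```

The loop is only a readability aid; in practice the vectorized call is the one to use.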
Note
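The tensors condition, input, other must be broadcastable.
- Parameters
condition (BoolTensor) – When True (nonzero), yield input, otherwise yield other.
input (Tensor or Scalar) – value (if input is a scalar) or values selected at indices where condition is True.
other (Tensor or Scalar) – value (if other is a scalar) or values selected at indices where condition is False.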
- Keyword Arguments
out (Tensor, optional) – the output tensor.
- Returns
A tensor of shape equal to the broadcasted shape of condition, input, other.
- Return type
Tensor
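A minimal sketch of how broadcasting determines the result shape, and of writing into a preallocated tensor through `out`. The shapes and the buffer name `buf` are arbitrary choices for illustration; the sketch assumes `out` already has the broadcasted shape and the same dtype as the result, and it passes other as a 0-dim tensor to stay on the tensor overload.

```python
import torch

condition = torch.tensor([[True], [False], [True]])  # shape (3, 1)
inp = torch.arange(4.0)                               # shape (4,)
other = torch.tensor(0.0)                             # 0-dim tensor

# The three shapes broadcast to (3, 4).
result = torch.where(condition, inp, other)
print(result.shape)                                         # torch.Size([3, 4])
print(torch.broadcast_shapes(condition.shape, inp.shape))   # torch.Size([3, 4])

# Preallocated output buffer (hypothetical name) with matching shape and dtype.
buf = torch.empty(3, 4)
torch.where(condition, inp, other, out=buf)
print(torch.equal(buf, result))  # True
```

Rows where condition is True take the values of inp; the False row is filled with other.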
Example:
>>> import torch
>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, 1.0, 0.0)
tensor([[0., 1.],
        [1., 0.],
        [1., 0.]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)
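Beyond the random-tensor example above, two common idioms built on the same call are sketched below: replacing NaN entries with zeros, and a ReLU-style clamp. These are illustrative uses rather than the only way to do either; torch.relu appears here only as a cross-check.

```python
import torch

# Replace NaN entries with 0 while leaving all other values untouched.
x = torch.tensor([1.0, float("nan"), -2.0, float("nan")])
cleaned = torch.where(torch.isnan(x), torch.zeros_like(x), x)
print(cleaned)

# ReLU-style clamp: keep positive values, send everything else to 0.0.
y = torch.tensor([-1.5, 0.0, 2.5])
relu_like = torch.where(y > 0, y, 0.0)
print(torch.equal(relu_like, torch.relu(y)))  # True
```

Passing a Python scalar such as 0.0 for other keeps the dtype of the tensor argument, as in the float64 example above.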
- torch.where(condition) → tuple of LongTensor
torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
Note
See also torch.nonzero().
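The single-argument form returns one index tensor per dimension of condition. A small sketch, assuming a 2-D boolean mask, checks it against torch.nonzero(..., as_tuple=True) and then uses the resulting tuple for advanced indexing.

```python
import torch

condition = torch.tensor([[True, False, True],
                          [False, False, True]])

# One LongTensor of indices per dimension, in row-major order.
rows, cols = torch.where(condition)
print(rows)  # tensor([0, 0, 1])
print(cols)  # tensor([0, 2, 2])

# Same result as nonzero with as_tuple=True.
ref = torch.nonzero(condition, as_tuple=True)
print(all(torch.equal(a, b) for a, b in zip((rows, cols), ref)))  # True

# The index tuple can be used directly to gather the selected elements.
values = torch.arange(6).reshape(2, 3)
print(values[rows, cols])  # tensor([0, 2, 5])
```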