Environment:
Google Colab
Python 3.7.11
torch 1.9.0+cu102
Code and output:
import torch
b = torch.tensor([[1,1,1],[4,5,6]])
print(b.T)
print(torch.tensor([1,4]) in b.T)
print(torch.tensor([2,1]) in b.T)
print(torch.tensor([1,2]) in b.T)  # Not as expected
print(torch.tensor([2,5]) in b.T)  # Not as expected
----------------------------------------------------------
tensor([[1, 4],
        [1, 5],
        [1, 6]])
True
False
True
True
Problem
I want to check whether one tensor is contained in another, but the results above are confusing.
How does the in operator work here, and how should I use it to avoid the unexpected output above?
Also, is this caused by torch.Tensor.T? (When .T is not used and b is initialized directly as b = torch.tensor([[1,4],[1,5],[1,6]]), I still don't get the expected output.)
The source code of Tensor.__contains__ is as follows (edit: source code snippet found by @Rune):
def __contains__(self, element):
    r"""Check if `element` is present in tensor

    Args:
        element (Tensor or scalar): element to be checked
            for presence in current tensor"
    """
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.__contains__, (self,), self, element)
    if isinstance(element, (torch.Tensor, Number)):
        # type hint doesn't understand the __contains__ result array
        return (element == self).any().item()  # type: ignore[union-attr]
    raise RuntimeError(
        "Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s." %
        type(element)
    )
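To check this against the question's inputs, the sketch below evaluates both x in b.T and the same (element == self).any().item() expression that the source uses. contains_via_any is a hypothetical helper that only mirrors the method body; it is not a PyTorch API.

```python
import torch

b = torch.tensor([[1, 1, 1], [4, 5, 6]])
bt = b.T  # shape (3, 2): [[1, 4], [1, 5], [1, 6]]

def contains_via_any(element, tensor):
    # Hypothetical helper: mirrors the Tensor/scalar branch of __contains__ above
    return (element == tensor).any().item()

results = {}
for x in ([1, 4], [2, 1], [1, 2], [2, 5]):
    t = torch.tensor(x)
    # Pair of (built-in `in`, manual any() check) for the same input
    results[tuple(x)] = (t in bt, contains_via_any(t, bt))
    print(x, results[tuple(x)])

# Anything other than a Tensor or a Number reaches the final raise
try:
    [1, 4] in bt
except RuntimeError as err:
    print("RuntimeError:", err)
```

Both expressions agree for every input, reproducing the True, False, True, True output from the question, and passing a plain Python list raises the RuntimeError from the last branch.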
The __contains__ operator (used by the in syntax: x in b) is equivalent to applying torch.any to the boolean tensor x == b, which PyTorch computes element-wise with broadcasting:
>>> b = torch.tensor([[1, 1, 1],
...                   [4, 5, 6]])
>>> check_in = lambda x: torch.any(x == b.T)
Then
>>> check_in(torch.tensor([1,4]))
tensor(True)
>>> check_in(torch.tensor([2,1]))
tensor(False)
>>> check_in(torch.tensor([1,2]))
tensor(True)
>>> check_in(torch.tensor([2,5]))
tensor(True)
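What matters is the position of each element, not an exact match of a whole row: x is found if any of its elements equals the element in the same column of some row of b.T. For [1, 2], the 1 matches column 0 of every row, which is enough to return True; for [2, 5], the 5 matches column 1 of the row [1, 5].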
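Printing the intermediate boolean tensor makes this visible. A small sketch reusing the same b: x == b.T broadcasts x across every row of b.T, so x[0] is compared with column 0 and x[1] with column 1.

```python
import torch

b = torch.tensor([[1, 1, 1], [4, 5, 6]])
x = torch.tensor([1, 2])

# x (shape (2,)) is broadcast against b.T (shape (3, 2)), row by row
mask = x == b.T
print(mask)  # every row: True in column 0 (1 == 1), False in column 1

# A single True anywhere is enough for `in` to report a match
print(mask.any().item())
```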
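.T reverses the order of the dimensions (for a 2-D tensor it is equivalent to b.permute(1, 0)). It only changes which axis x is compared against and is not the cause of the behavior: building b = torch.tensor([[1,4],[1,5],[1,6]]) directly gives exactly the same results. The only constraint when using in is that x must broadcast against b, so a 1-D x needs b.shape[1] elements, or b.shape[0] elements if you're working with b.T.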
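The size constraint comes from the broadcasting rules of ==, not from in itself. A minimal sketch, assuming the same b as above: a length-2 x works against b.T, a length-1 x broadcasts like a scalar, and a length-3 x cannot be broadcast against the (3, 2) shape of b.T, so == raises a RuntimeError instead of in returning False.

```python
import torch

b = torch.tensor([[1, 1, 1], [4, 5, 6]])

ok_pair = torch.tensor([1, 4]) in b.T   # length 2 matches b.T's last dimension
ok_single = torch.tensor([5]) in b.T    # length 1 broadcasts against every element
print(ok_pair, ok_single)

mismatch_error = None
try:
    torch.tensor([1, 2, 3]) in b.T      # length 3 vs. last dimension 2
except RuntimeError as err:
    # Raised by the element-wise == inside __contains__, before any() runs
    mismatch_error = err
    print("RuntimeError:", err)
```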
>>> check_in = lambda x: torch.any(x == b)  # without .T, x needs 3 elements
>>> check_in(torch.tensor([1,1,1]))
tensor(True)
>>> check_in(torch.tensor([5,4,2]))
tensor(False)
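If the goal is to test whether x appears as a complete row of b.T (that is, as a column of b), in cannot do that on its own. One common approach, sketched here, is to require every element of a row to match with all(dim=1) before reducing over the rows with any(). row_in is a hypothetical helper name, not a PyTorch function, and it assumes mat is 2-D and x has mat.shape[1] elements.

```python
import torch

def row_in(x: torch.Tensor, mat: torch.Tensor) -> bool:
    """Return True if x equals some complete row of the 2-D tensor mat."""
    # (mat == x) has shape (rows, cols); all(dim=1) is True only for fully matching rows
    return bool((mat == x).all(dim=1).any())

b = torch.tensor([[1, 1, 1], [4, 5, 6]])
for x in ([1, 4], [2, 1], [1, 2], [2, 5]):
    print(x, row_in(torch.tensor(x), b.T))
```

With this check only [1, 4] is found in b.T; [1, 2] and [2, 5] now return False because no row of b.T matches them completely.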