rune

Reputation: 40

What is the mechanism of "torch.Tensor in torch.Tensor" in Python, and why does it behave so confusingly?

environment:

google colab

Python 3.7.11

torch 1.9.0+cu102

code and output

import torch
b = torch.tensor([[1,1,1],[4,5,6]])
print(b.T)

print(torch.tensor([1,4]) in b.T) # True, as expected ([1, 4] is a row of b.T)
print(torch.tensor([2,1]) in b.T) # False, as expected
print(torch.tensor([1,2]) in b.T) # True -- not as expected
print(torch.tensor([2,5]) in b.T) # True -- not as expected

----------------------------------------------------------
tensor([[1, 4],
        [1, 5],
        [1, 6]])
True
False
True
True

Problem

I want to check whether one tensor is contained in another, but the results above are confusing.

What is the mechanism of in? How should I use it to avoid the unexpected output above?

And is the problem caused by torch.Tensor.T? (Even without .T, initializing b = torch.tensor([[1,4],[1,5],[1,6]]) directly gives the same unexpected output.)

Upvotes: 1

Views: 1331

Answers (1)

Ivan

Reputation: 40618

The source code is as follows (edit: source code snippet found by @Rune):

def __contains__(self, element):
    r"""Check if `element` is present in tensor

    Args:
        element (Tensor or scalar): element to be checked
            for presence in current tensor
    """
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.__contains__, (self,), self, element)
    if isinstance(element, (torch.Tensor, Number)):
        # type hint doesn't understand the __contains__ result array
        return (element == self).any().item()  # type: ignore[union-attr]

    raise RuntimeError(
        "Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s." %
        type(element)
    )

The __contains__ operator (invoked by the in syntax: x in b) is equivalent to applying torch.any to the x == b boolean mask:

>>> b = torch.tensor([[1, 1, 1],
                      [4, 5, 6]])

>>> check_in = lambda x: torch.any(x == b.T)

Then

>>> check_in(torch.tensor([1,4]))
tensor(True)

>>> check_in(torch.tensor([2,1]))
tensor(False)

>>> check_in(torch.tensor([1,2]))
tensor(True)

>>> check_in(torch.tensor([2,5]))
tensor(True)

Because x == b.T broadcasts x against every row of b.T, a single element matching at any position is enough to return True; in does not require an entire row (or column) to match exactly.
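To make the broadcasting explicit, here is a minimal sketch of the intermediate boolean mask for the [1, 2] case (note that no row of b.T equals [1, 2], yet every row starts with 1):

```python
import torch

b = torch.tensor([[1, 1, 1],
                  [4, 5, 6]])

# x is broadcast against each row of b.T, producing an element-wise mask.
x = torch.tensor([1, 2])
mask = x == b.T
print(mask)
# tensor([[ True, False],
#         [ True, False],
#         [ True, False]])

# `in` reduces this mask with any(), so one matching element suffices.
print(mask.any().item())  # True, even though no row equals [1, 2]
```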


.T reverses the order of the dimensions: it is equivalent to b.permute(1, 0) and has no effect on the results. The only constraint when using in is that x's length must match that of a row of b, i.e. b.shape[1]. If you are working with b.T, that becomes b.shape[0].

>>> check_in = lambda x: torch.any(x == b)

>>> check_in(torch.tensor([1,1,1]))
tensor(True)

>>> check_in(torch.tensor([5,4,2]))
tensor(False)
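If you actually want exact row membership rather than the element-wise behavior of in, you can reduce the comparison per row first. Below is a sketch with a hypothetical helper row_in (not part of PyTorch), which requires an entire row of t to equal x:

```python
import torch

b = torch.tensor([[1, 1, 1],
                  [4, 5, 6]])

# Hypothetical helper: True only if x equals a *whole* row of t.
def row_in(x, t):
    # (t == x) is the element-wise mask; all(dim=1) keeps only full-row
    # matches; any() checks whether at least one row matched completely.
    return (t == x).all(dim=1).any().item()

print(row_in(torch.tensor([1, 4]), b.T))  # True  -> [1, 4] is a row of b.T
print(row_in(torch.tensor([1, 2]), b.T))  # False -> no full-row match
print(row_in(torch.tensor([2, 5]), b.T))  # False -> no full-row match
```

This gives the behavior the question expected from in.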

Upvotes: 1
