peer

Reputation: 4679

How to find the index of a tensor in a list?

I want to find the index of the smallest tensor (by some key function) in a list li. So I used min to get the smallest element and then called li.index(min_el). My MWE suggests that tensors somehow don't work with index.

>>> import torch
>>> li = [torch.ones(1, 1), torch.zeros(2, 2)]
>>> li.index(li[0])
0
>>> li.index(li[1])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../local/lib/python2.7/site-packages/torch/tensor.py", line 330, in __eq__
    return self.eq(other)
RuntimeError: inconsistent tensor size at /b/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:2679

I could of course write my own index function that first compares sizes and then checks the elements, e.g.:

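import torch

def index(tensors, element):
    # Return the position of the first tensor in `tensors` equal to `element`, or -1.
    for i, el in enumerate(tensors):
        # Check sizes first, then compare all elements exactly with torch.equal
        if el.size() == element.size() and torch.equal(el, element):
            return i
    return -1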

I'm just wondering: why doesn't index work? Is there a smarter way of doing this, without writing it by hand, that I'm missing?

Upvotes: 0

Views: 559

Answers (1)

nnnmmm

Reputation: 8744

You could find the index directly by passing enumerate(li) to min, with a key function that operates on the second element of each (index, tensor) tuple. For instance, if your key compares the first element of each tensor, you could use

ix, _ = min(enumerate(li), key=lambda x: x[1][0, 0])
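A self-contained sketch of the same enumerate-and-min idea. It assumes, purely for illustration, that the key is the sum of each tensor's entries; tensor_sum is a hypothetical key function, so swap in your own. Because the key turns each tensor into a plain Python float, min only ever compares numbers, and tensors of different shapes are never compared with each other.

```python
import torch

li = [torch.ones(1, 1), torch.zeros(2, 2), torch.full((3,), 2.0)]

def tensor_sum(t):
    # Hypothetical key: reduce the tensor to a plain Python float,
    # so min() compares numbers rather than tensors
    return t.sum().item()

# enumerate yields (index, tensor) pairs; the key only looks at the tensor, pair[1]
ix, smallest = min(enumerate(li), key=lambda pair: tensor_sum(pair[1]))
print(ix)  # 1, since the 2x2 zero tensor has the smallest sum
```

Since the index comes out of enumerate together with the element, there is no second lookup and no tensor equality test at all.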

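The reason index succeeds for the first element is that list.index compares each element x with the value using the equivalent of x is value or x == value (CPython checks identity before calling ==). For li.index(li[0]), the first element is the very object being searched for, so the identity check succeeds and the failing tensor comparison never happens. For li.index(li[1]), Python first compares li[0] with li[1]; they are different objects, so the tensor == is evaluated on a 1x1 and a 2x2 tensor and raises the "inconsistent tensor size" error. This is why li.index(li[0]) works but this fails: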

>>> y = torch.ones(1, 1)
>>> li.index(y)
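The identity shortcut can be observed without PyTorch. The sketch below uses a hypothetical Picky class whose __eq__ raises, standing in for a tensor whose == cannot be turned into a plain True/False. list.index finds the first element only because it is the same object; looking up the second element forces an == call on the first one, which raises. A lookup by identity (is) avoids == entirely, which is enough whenever you hold the very object stored in the list, as you do with the result of min. index_by_identity is a hypothetical helper, not a built-in.

```python
class Picky:
    """Hypothetical stand-in for a tensor: == refuses to produce a plain bool."""
    def __eq__(self, other):
        raise RuntimeError("== is not a plain bool for this type")

a, b = Picky(), Picky()
items = [a, b]

print(items.index(a))  # 0: identity match on items[0], __eq__ is never called

try:
    items.index(b)  # items[0] is not b, so items[0] == b runs and raises
except RuntimeError as err:
    print("index failed:", err)

def index_by_identity(seq, obj):
    """Return the position of obj itself in seq (compared with `is`), or -1."""
    return next((i for i, x in enumerate(seq) if x is obj), -1)

print(index_by_identity(items, b))  # 1
```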

Upvotes: 1
