I want to find the index of the smallest tensor (by some key function) in a list li. So I used min and afterwards li.index(min_el). My MWE suggests that somehow tensors don't work with index:
>>> import torch
>>> li = [torch.ones(1,1), torch.zeros(2,2)]
>>> li.index(li[0])
0
>>> li.index(li[1])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../local/lib/python2.7/site-packages/torch/tensor.py", line 330, in __eq__
    return self.eq(other)
RuntimeError: inconsistent tensor size at /b/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:2679
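I can of course write my own index function that first checks the size and then compares element-wise, e.g.:

def index(tensors, element):
    for i, el in enumerate(tensors):
        # Only compare element-wise when the shapes match
        if el.size() == element.size():
            # all() is true only if every entry of el equals the matching entry of element
            if (el == element).all():
                return i
    return -1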
I was just wondering: why doesn't index work here? Is there a smarter way of doing this that I'm missing, rather than writing it by hand?
You could find the index directly by calling min on enumerate(li) with a key function that operates on the second element of each (index, tensor) tuple. For instance, if your key function compares the first element of each tensor, you could use
ix, _ = min(enumerate(li), key=lambda x: x[1][0, 0])
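The same approach can be wrapped in a small helper. This is a minimal sketch: argmin_by is a hypothetical name, and plain nested Python lists stand in for tensors so the snippet runs without PyTorch. Since min only compares the key values, the elements themselves are never compared with ==, so the same call should work on a list of tensors with a key such as lambda t: t[0, 0].

```python
def argmin_by(items, key):
    """Return the index of the item with the smallest key(item).

    Only key values are compared, so the items never need a working ==
    (which is what breaks list.index for tensors of different sizes).
    Raises ValueError for an empty sequence, like min itself.
    """
    # enumerate yields (index, item) pairs; the key only looks at the item
    ix, _ = min(enumerate(items), key=lambda pair: key(pair[1]))
    return ix


# Hypothetical stand-ins for a 1x1 tensor of ones and a 2x2 tensor of zeros
li = [[[1.0]], [[0.0, 0.0], [0.0, 0.0]]]
print(argmin_by(li, key=lambda m: m[0][0]))  # 1
```

Ties resolve to the lowest index, because min returns the first minimal item it encounters.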
The reason index succeeds for the first element is that Python's list.index checks identity before equality: for each element x in the list it effectively evaluates x is value or x == value. Since li[0] is li[0], the failing equality comparison never happens. This is why li.index(li[0]) works, but this fails:
y = torch.ones(1,1)
li.index(y)
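The identity shortcut can be observed without PyTorch. In this sketch NoEq is a hypothetical class whose __eq__ always raises, standing in for tensors whose comparison fails:

```python
class NoEq:
    """Hypothetical element type whose == always raises, like mismatched tensors."""

    def __eq__(self, other):
        raise RuntimeError("== called")


a, b = NoEq(), NoEq()
items = [a, b]

# a is items[0], so the identity check matches and == is never called
print(items.index(a))  # 0

try:
    # items[0] is not b, so Python falls back to items[0] == b, which raises
    items.index(b)
except RuntimeError as exc:
    print("failed:", exc)  # failed: == called
```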