Reputation: 301
I have a 2D tensor with at least one nonzero element in each row, like this:
import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
I want a tensor containing the index of the first nonzero element in each row:
indices = tensor([[2],
                  [3]])
How can I compute this in PyTorch?
Upvotes: 12
Views: 13367
Reputation: 41
Building on @Seppo's answer, we can drop the assumption that "all the nonzero values are equal" by first building a 0/1 mask from the original tensor and then applying argmax to the mask:
# tmp = some tensor of whatever shape and values
indices = torch.argmax((tmp != 0).to(dtype=torch.int), dim=-1)
However, if a row is all zeros, argmax returns 0 for it, which is not the index of any nonzero element. The question implies that every row has at least one nonzero element, so this case should not arise here.
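A minimal sketch that makes the all-zero case explicit, assuming a sentinel of -1 is acceptable for such rows. The helper name `first_nonzero_index` and its `fill` parameter are hypothetical, not part of PyTorch. Like the answer above, it relies on `torch.argmax` returning the first index when several entries tie for the maximum, which the PyTorch documentation states.

```python
import torch

def first_nonzero_index(t: torch.Tensor, fill: int = -1) -> torch.Tensor:
    """Hypothetical helper: index of the first nonzero entry along the last dim,
    or `fill` for rows that contain no nonzero entry."""
    mask = t != 0
    # On a 0/1 mask, argmax picks the first 1 because ties resolve to the first index
    idx = torch.argmax(mask.to(torch.int), dim=-1)
    # All-zero rows would otherwise report index 0; replace them with `fill`
    has_nonzero = mask.any(dim=-1)
    return torch.where(has_nonzero, idx, torch.full_like(idx, fill))

tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0]], dtype=torch.float)
print(first_nonzero_index(tmp).tolist())  # [2, 3, -1]
```

Using `torch.where` keeps everything vectorized; if a sentinel is not wanted, `has_nonzero` alone can be returned alongside the indices.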
Upvotes: 3
Reputation: 3663
Assuming that all the nonzero values are equal and positive, argmax returns the first index, because it picks the first occurrence when the maximum appears more than once:
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0]])
indices = tmp.argmax(1)
Upvotes: 7
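To see why the equal-values assumption matters, here is a small check on a hypothetical tensor `vals` whose nonzero entries differ. Plain argmax then returns the position of the largest value rather than the first nonzero one, while argmax over a 0/1 mask still gives the first nonzero position.

```python
import torch

# Row 0 violates the "all nonzero values equal" assumption (1 and 5)
vals = torch.tensor([[0, 1, 0, 5],
                     [0, 0, 2, 2]])
# Plain argmax: row 0 picks the larger 5, not the first nonzero element
print(vals.argmax(1).tolist())               # [3, 2]
# Argmax over a 0/1 mask recovers the first nonzero position in both rows
print((vals != 0).int().argmax(1).tolist())  # [1, 2]
```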
Reputation: 141
I simplified Iman's approach as follows (like the original, it assumes the nonzero values are equal and positive, e.g. all 1):
idx = torch.arange(tmp.shape[1], 0, -1)
tmp2= tmp * idx
indices = torch.argmax(tmp2, 1, keepdim=True)
Upvotes: 14
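A sketch of one variation, assuming the goal is to drop the equal-values requirement: apply the same descending weights to the 0/1 mask `tmp != 0` instead of to `tmp` itself. The names `n` and `weighted` are illustrative only.

```python
import torch

tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
n = tmp.shape[1]
# Weights n, n-1, ..., 1: earlier columns get strictly larger weights
idx = torch.arange(n, 0, -1)
# Weighting the boolean mask (bool * int64 -> int64) ignores the actual values in tmp
weighted = (tmp != 0) * idx
indices = torch.argmax(weighted, 1, keepdim=True)
print(indices.tolist())  # [[2], [3]]
```

Because the weights are strictly decreasing, no two nonzero positions tie, so the result does not depend on argmax's tie-breaking rule; an all-zero row still yields 0.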
Reputation: 301
I found a tricky answer to my own question:
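import torch

tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
# Descending weights 7, 6, ..., 1: earlier columns get larger weights
idx = reversed(torch.Tensor(range(1, 8)))
print(idx)
# Scale column b of every row by idx[b]
tmp2 = torch.einsum("ab,b->ab", (tmp, idx))
print(tmp2)
# The first nonzero column now holds the largest value in each row
indices = torch.argmax(tmp2, 1, keepdim=True)
print(indices)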
The result is:
tensor([7., 6., 5., 4., 3., 2., 1.])
tensor([[0., 0., 5., 0., 3., 0., 0.],
[0., 0., 0., 4., 3., 0., 0.]])
tensor([[2],
[3]])
Upvotes: 7
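The einsum step only scales each column by a weight, so a plain broadcast multiply computes the same tensor. A small check, assuming `torch.arange(7, 0, -1, dtype=torch.float)` as a stand-in for `reversed(torch.Tensor(range(1, 8)))` (both give `[7., 6., ..., 1.]`):

```python
import torch

tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
idx = torch.arange(7, 0, -1, dtype=torch.float)
# "ab,b->ab" keeps both indices and sums over nothing: out[a, b] = tmp[a, b] * idx[b]
via_einsum = torch.einsum("ab,b->ab", tmp, idx)
# Broadcasting the length-7 vector across the rows gives the same product
via_broadcast = tmp * idx
print(torch.equal(via_einsum, via_broadcast))  # True
```

Either tensor can be passed to `torch.argmax(..., 1, keepdim=True)`; the broadcast form simply avoids parsing an einsum equation.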