Given an input tensor x and a tensor of indices idxs, I want to retrieve all elements of x whose index is not present in idxs. That is, I want the complement of what torch.gather returns.
Example with torch.gather:
>>> x = torch.arange(30).reshape(3,10)
>>> idxs = torch.tensor([[1,2,3], [4,5,6], [7,8,9]], dtype=torch.long)
>>> torch.gather(x, 1, idxs)
tensor([[ 1,  2,  3],
        [14, 15, 16],
        [27, 28, 29]])
What I actually want to achieve is:
tensor([[ 0,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26]])
What would be an effective and efficient implementation, ideally using torch utilities? I would like to avoid for-loops.
I'm assuming idxs has only unique elements in its deepest dimension. For example, idxs would be the result of calling torch.topk.
You are looking to construct a tensor of shape (x.size(0), x.size(1) - idxs.size(1)), here (3, 7), which holds the complementary indices of idxs with respect to the shape of x, i.e.:
tensor([[0, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6]])
I propose to first build a tensor shaped like x that marks the positions we want to keep and those we want to discard, a sort of mask. This can be done using torch.scatter, which here writes 0s at the positions to discard, namely m[i, idxs[i][j]] = 0:
>>> m = torch.ones_like(x).scatter(1, idxs, 0)
>>> m
tensor([[1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
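To spell out what the scatter call does along dim=1, here is a small sketch comparing it against an explicit double loop. The loop and the name m_loop are only there for illustration (they are not part of the proposed solution), and the inputs are the ones from the question.

```python
import torch

x = torch.arange(30).reshape(3, 10)
idxs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.long)

# Vectorized mask: start from all ones, write 0 at every (i, idxs[i, j]).
m = torch.ones_like(x).scatter(1, idxs, 0)

# Illustrative loop version (hypothetical name m_loop), showing the same rule.
m_loop = torch.ones_like(x)
for i in range(idxs.size(0)):
    for j in range(idxs.size(1)):
        m_loop[i, idxs[i, j]] = 0

# Both approaches should produce the same mask.
assert torch.equal(m, m_loop)
```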
Then grab the nonzero entries (the complement of idxs). Take their column indices, i.e. the second column of the nonzero() output, and reshape to the target shape. The reshape is valid because every row of idxs removes the same number of columns:
>>> idxs_ = m.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
>>> idxs_
tensor([[0, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6]])
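If the nonzero step is unclear, the following toy sketch may help. It uses a made-up 2x4 mask (not the one from the question) to show that nonzero() returns one (row, column) pair per nonzero entry in row-major order, which is why the column indices can be reshaped back into rows.

```python
import torch

# Hypothetical small mask, two kept positions per row.
m = torch.tensor([[1, 0, 0, 1],
                  [0, 1, 1, 0]])

coords = m.nonzero()                 # shape (4, 2): (row, col) pairs, row by row
cols = coords[:, 1]                  # keep only the column index of each pair
kept = cols.reshape(m.size(0), -1)   # regroup: same number of columns per row

print(coords.tolist())  # [[0, 0], [0, 3], [1, 1], [1, 2]]
print(kept.tolist())    # [[0, 3], [1, 2]]
```

Note that this regrouping only works when every row keeps the same number of entries, which holds here because idxs has a fixed width and unique entries per row.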
Now you know what to do, right? Same as for the torch.gather example you gave, but this time with idxs_:
>>> torch.gather(x, 1, idxs_)
tensor([[ 0,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26]])
In summary:
>>> idxs_ = torch.ones_like(x).scatter(1, idxs, 0) \
.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
>>> torch.gather(x, 1, idxs_)
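For reuse, the steps can be wrapped in a function. This is a minimal sketch for 2-D inputs, assuming each row of idxs holds unique, in-range column indices. The names gather_complement and gather_complement_masked are my own, and the second function is a variant I am adding (boolean-mask indexing instead of nonzero plus gather), not something from the steps above.

```python
import torch


def gather_complement(x: torch.Tensor, idxs: torch.Tensor) -> torch.Tensor:
    """Return entries of 2-D x (along dim=1) whose column index is not in idxs.

    Assumes each row of idxs contains unique, in-range column indices.
    """
    # 1 marks columns to keep, 0 marks columns listed in idxs.
    mask = torch.ones_like(x, dtype=torch.long).scatter(1, idxs, 0)
    # Column indices of kept entries, regrouped row by row.
    idxs_ = mask.nonzero()[:, 1].reshape(x.size(0), x.size(1) - idxs.size(1))
    return torch.gather(x, 1, idxs_)


def gather_complement_masked(x: torch.Tensor, idxs: torch.Tensor) -> torch.Tensor:
    """Variant under the same assumptions: index x directly with a boolean mask."""
    keep = torch.ones_like(x, dtype=torch.long).scatter(1, idxs, 0).bool()
    # Boolean indexing flattens in row-major order, so reshape restores the rows.
    return x[keep].reshape(x.size(0), -1)


x = torch.arange(30).reshape(3, 10)
idxs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.long)
out = gather_complement(x, idxs)
out_masked = gather_complement_masked(x, idxs)

# The torch.topk use case from the question: keep everything except the top 3.
scores = torch.rand(3, 10)
_, top = torch.topk(scores, k=3, dim=1)
rest = gather_complement(scores, top)  # shape (3, 7)
```

Both functions should return the (3, 7) tensor from the question for the example inputs. The masked variant skips building idxs_, which may be convenient if you only need the values and not the complementary indices.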