dcfg

Reputation: 891

pytorch - reciprocal of torch.gather

Given an input tensor x and a tensor of indices idxs, I want to retrieve all elements of x whose index is not present in idxs. That is, I want the complement of what torch.gather returns.

Example with torch.gather:

>>> x = torch.arange(30).reshape(3,10)
>>> idxs = torch.tensor([[1,2,3], [4,5,6], [7,8,9]], dtype=torch.long)
>>> torch.gather(x, 1, idxs)
tensor([[ 1,  2,  3],
        [14, 15, 16],
        [27, 28, 29]])

What I actually want to achieve is:

tensor([[ 0,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26]])

What would be an effective and efficient implementation, ideally one that relies on torch utilities? I would like to avoid for-loops.

I'm assuming idxs has only unique elements in its deepest dimension. For example idxs would be the result of calling torch.topk.

Upvotes: 3

Views: 1817

Answers (1)

Ivan

Reputation: 40778

You are looking to construct a tensor of shape (x.size(0), x.size(1) - idxs.size(1)), here (3, 7), that holds the complementary indices of idxs with respect to the shape of x, i.e.:

tensor([[0, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6]])

I propose to first build a tensor shaped like x that marks the positions we want to keep and those we want to discard, a sort of mask. This can be done using torch.scatter, which here writes 0s at the desired locations, namely m[i, idxs[i][j]] = 0:

>>> m = torch.ones_like(x).scatter(1, idxs, 0)
>>> m
tensor([[1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
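
To make the scatter semantics concrete, here is a small sketch that compares the vectorized call with an explicit double loop. The loop and the name m_ref are only there for illustration (they are not part of the proposed solution, which stays loop-free); x and idxs are the tensors from the question.

```python
import torch

x = torch.arange(30).reshape(3, 10)
idxs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.long)

# Vectorized version: write 0 at every position (i, idxs[i][j]).
m = torch.ones_like(x).scatter(1, idxs, 0)

# Explicit-loop reference (m_ref is a hypothetical name, illustration only).
m_ref = torch.ones_like(x)
for i in range(idxs.size(0)):
    for j in range(idxs.size(1)):
        m_ref[i, idxs[i, j]] = 0

print(torch.equal(m, m_ref))  # True
```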

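Then grab the nonzero positions (the complement of idxs). nonzero() returns one (row, column) pair per nonzero entry, in row-major order, so select the column indices with [:, 1] and reshape to the target shape. The reshape works because every row keeps the same number of entries, which holds when each row of idxs contains unique values: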

>>> idxs_ = m.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
>>> idxs_
tensor([[0, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6]])

Now you know what to do, right? Same as for the torch.gather example you gave, but this time with idxs_:

>>> torch.gather(x, 1, idxs_)
tensor([[ 0,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26]])

In summary:

>>> idxs_ = torch.ones_like(x).scatter(1, idxs, 0) \
        .nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))

>>> torch.gather(x, 1, idxs_)
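
For reuse, the steps above could be wrapped in a function along these lines. This is a minimal sketch, assuming x is 2-D, the gather happens along dim 1, and each row of idxs holds unique indices and fewer of them than x has columns. The names gather_complement and gather_complement_masked are made up for this example. The second function is a variation that is not part of the steps above: it converts the mask to bool and indexes x with it directly, which skips the nonzero and gather steps.

```python
import torch


def gather_complement(x, idxs):
    """Hypothetical helper: per row, the elements of x whose column is not in idxs."""
    # 1 marks positions to keep, 0 marks positions listed in idxs.
    mask = torch.ones_like(x, dtype=torch.long).scatter(1, idxs, 0)
    keep = x.size(1) - idxs.size(1)  # entries kept per row
    # Column indices of the kept entries, one row per row of x.
    idxs_ = mask.nonzero()[:, 1].reshape(-1, keep)
    return torch.gather(x, 1, idxs_)


def gather_complement_masked(x, idxs):
    """Hypothetical variant: index x with a boolean mask instead of gathering."""
    mask = torch.ones_like(x, dtype=torch.long).scatter(1, idxs, 0).bool()
    # Boolean indexing returns the kept elements flattened in row-major order.
    return x[mask].reshape(x.size(0), -1)


x = torch.arange(30).reshape(3, 10)
idxs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.long)

out = gather_complement(x, idxs)
out_masked = gather_complement_masked(x, idxs)
print(out)
print(torch.equal(out, out_masked))  # True
```

Both versions depend on every row dropping the same number of elements. If idxs can contain duplicates within a row, the final reshape will either fail or silently misalign the rows.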

Upvotes: 3
