spadel

Reputation: 1036

How to select indices according to another tensor in pytorch

The task seems simple, but I cannot figure out how to do it.

So what I have are two tensors:

To be more concrete, the indices are between 0 and 15, and I want to get the output:

out = value[:, :, :, x_indices, y_indices]

The output should therefore have shape (2, 5, 2). Can anybody help me here? Thanks a lot!
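For reference, one way to express the desired selection directly is advanced indexing with broadcast batch indices. A minimal sketch, assuming (the original tensors are not shown) that value has shape (2, 5, 2, 16, 16) and that x_indices and y_indices each have shape (2, 5); the random data is only illustrative:

```python
import torch

# Hypothetical shapes: value is (B, N, C, H, W), with one (x, y) location per (b, n).
B, N, C, H, W = 2, 5, 2, 16, 16
value = torch.randn(B, N, C, H, W)
x_indices = torch.randint(0, H, (B, N))  # indexes dim 3
y_indices = torch.randint(0, W, (B, N))  # indexes dim 4

# Index tensors for the first two dims that broadcast against (B, N).
b_idx = torch.arange(B).view(B, 1)  # (B, 1)
n_idx = torch.arange(N).view(1, N)  # (1, N)

# The advanced indices broadcast to (B, N); the sliced channel dim C is appended,
# so out[b, n, c] = value[b, n, c, x_indices[b, n], y_indices[b, n]].
out = value[b_idx, n_idx, :, x_indices, y_indices]
print(out.shape)  # torch.Size([2, 5, 2])
```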

Edit:

I tried the gather suggestion, but unfortunately it does not seem to work (I changed the dimensions, but that should not matter):

First I generate a coordinate grid:

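import torch
# y varies along the rows and x along the columns; both are (16, 16, 1)
y_t = torch.linspace(-1., 1., 16, device='cpu').reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16, device='cpu').reshape(1, 16).repeat(16, 1).unsqueeze(-1)
# (1, 2, 16, 16): channel 0 holds y, channel 1 holds x
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
# add a second axis and repeat it 3 times: (1, 3, 2, 16, 16)
grid = grid.unsqueeze(1).repeat(1, 3, 1, 1, 1)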
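To see which index selects what, it can help to check the grid's layout. A short sketch using the same construction, checking that grid has shape (1, 3, 2, 16, 16), that channel 0 (y) changes only from row to row (dim 3), and that channel 1 (x) changes only from column to column (dim 4):

```python
import torch

# Same construction as in the question (CPU is the default device).
y_t = torch.linspace(-1., 1., 16).reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16).reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(1, 3, 1, 1, 1)

lin = torch.linspace(-1., 1., 16)
assert grid.shape == (1, 3, 2, 16, 16)

# Channel 0 (y): varies down the rows, constant along a row.
assert torch.equal(grid[0, 0, 0, :, 0], lin)
assert torch.all(grid[0, 0, 0, 0, :] == -1.)
# Channel 1 (x): varies along the columns, constant down a column.
assert torch.equal(grid[0, 0, 1, 0, :], lin)
assert torch.all(grid[0, 0, 1, :, 0] == -1.)
```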

Next, I create some indices; in this case every index is 1:

indices = torch.ones([1, 3, 2], dtype=torch.int64)  # (1, 3, 2): one (row, col) pair per entry

Then I apply your method:

indices = indices.unsqueeze(-1).unsqueeze(-1)  # (1, 3, 2, 1, 1)
new_coords = torch.gather(grid, -1, indices).squeeze(-1).squeeze(-1)  # (1, 3, 2)

Finally, I manually select index 1 for both the x and the y coordinate:

new_coords_manual = grid[:, :, :, 1, 1]

This outputs the following new coordinates:

new_coords
tensor([[[-1.0000, -0.8667],
         [-1.0000, -0.8667],
         [-1.0000, -0.8667]]])

new_coords_manual
tensor([[[-0.8667, -0.8667],
         [-0.8667, -0.8667],
         [-0.8667, -0.8667]]])

As you can see, only one of the two indices is applied: the first value, -1.0000, comes from row 0 instead of row 1. Do you have an idea how to fix that?
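The cause is that torch.gather indexes a single dimension. With dim=-1 and indices reshaped to (1, 3, 2, 1, 1), the rule is out[a][b][c][d][e] = grid[a][b][c][d][indices[a][b][c][d][e]], and because indices.size(3) == 1 the row d is always 0. A small sketch reproducing this element by element:

```python
import torch

y_t = torch.linspace(-1., 1., 16).reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16).reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(1, 3, 1, 1, 1)  # (1, 3, 2, 16, 16)

indices = torch.ones([1, 3, 2], dtype=torch.int64).unsqueeze(-1).unsqueeze(-1)  # (1, 3, 2, 1, 1)
gathered = torch.gather(grid, -1, indices)  # same shape as indices

# Only the column comes from indices; the row position d is fixed at 0.
for b in range(3):
    for c in range(2):
        col = indices[0, b, c, 0, 0].item()
        assert gathered[0, b, c, 0, 0] == grid[0, b, c, 0, col]

# So the y channel is read from row 0 (value -1.0) rather than row 1.
assert gathered[0, 0, 0, 0, 0].item() == -1.0
```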

Upvotes: 9

Views: 13494

Answers (2)

spadel

Reputation: 1036

I figured it out, thanks again @Ivan for your help! :)

The problem was that I unsqueezed the last dimension, whereas I should have unsqueezed the middle dimensions so that the indices end up in the last dimension:

import torch

y_t = torch.linspace(-1., 1., 16, device='cpu').reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16, device='cpu').reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(2, 3, 1, 1, 1)  # (2, 3, 2, 16, 16)

# (2, 3, 2) -> (2, 3, 1, 1, 2): the index pair moves to the last dimension
indices = torch.ones([2, 3, 2], dtype=torch.int64).unsqueeze(-2).unsqueeze(-2)
# gather along dim 3: out[a, b, 0, 0, e] = grid[a, b, 0, indices[a, b, 0, 0, e], e]
new_coords = torch.gather(grid, 3, indices).squeeze(-2).squeeze(-2)

new_coords_manual = grid[:, :, :, 1, 1]

Now new_coords equals new_coords_manual.
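Because indices has size 1 in dim 2 here, this gather reads only channel 0 (out[a, b, e] = grid[a, b, 0, indices[a, b, e], e]); it matches the manual selection because this particular grid uses the same linspace for y and x. For an arbitrary value tensor, a sketch (assuming the last index dimension holds (row, col); the random data is illustrative) is to flatten the two spatial axes and gather with the linear index row * W + col, repeated for every channel:

```python
import torch

# Hypothetical general tensor (not symmetric like the coordinate grid).
B, N, C, H, W = 2, 3, 2, 16, 16
value = torch.randn(B, N, C, H, W)
indices = torch.randint(0, 16, (B, N, 2))  # last dim: (row, col)

rows, cols = indices[..., 0], indices[..., 1]          # (B, N) each
linear = rows * W + cols                               # row-major index into H * W
linear = linear[:, :, None, None].expand(B, N, C, 1)   # same location for every channel

flat = value.flatten(start_dim=3)                      # (B, N, C, H * W)
out = torch.gather(flat, 3, linear).squeeze(-1)        # (B, N, C)

# Reference result via explicit advanced indexing.
b_idx = torch.arange(B)[:, None]
n_idx = torch.arange(N)[None, :]
expected = value[b_idx, n_idx, :, rows, cols]          # (B, N, C)
print(torch.equal(out, expected))  # True
```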

Upvotes: 3

Ivan

Reputation: 40618

What you could do is flatten the first three axes together and apply torch.gather:

>>> grid.flatten(start_dim=0, end_dim=2).shape
torch.Size([6, 16, 16])

>>> torch.gather(grid.flatten(0, 2), dim=1, index=indices)
tensor([[[-0.8667, -0.8667],
         [-0.8667, -0.8667],
         [-0.8667, -0.8667]]])

As explained in the torch.gather documentation, this computes:

out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
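The rule can be checked directly with a loop. A sketch on random data (the shapes are chosen only for illustration) comparing torch.gather with dim=1 against the element-wise definition:

```python
import torch

inp = torch.randn(6, 16, 16)
# For d != 1, index.size(d) must not exceed inp.size(d).
index = torch.randint(0, 16, (6, 4, 16))

out = torch.gather(inp, 1, index)  # output has the shape of index: (6, 4, 16)

# out[i][j][k] = inp[i][index[i][j][k]][k]
manual = torch.empty_like(out)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        for k in range(index.size(2)):
            manual[i, j, k] = inp[i, index[i, j, k], k]

print(torch.equal(out, manual))  # True
```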

Upvotes: 3
