Reputation: 1036
The task seems to be simple, but I cannot figure out how to do it.
So what I have are two tensors:
indices with shape (2, 5, 2), where the last dimension holds the indices in the x and y dimensions
value with shape (2, 5, 2, 16, 16), where I want the last two dimensions to be selected with the x and y indices
To be more concrete, the indices are between 0 and 15 and I want to get an output like:
out = value[:, :, :, x_indices, y_indices]
that is, out[b, n, c] = value[b, n, c, indices[b, n, 0], indices[b, n, 1]]. The output should therefore have shape (2, 5, 2). Can anybody help me here? Thanks a lot!
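To pin down the intended result, here is a minimal sketch (not from the original thread) assuming indices[..., 0] is the x index into the second-to-last dimension of value and indices[..., 1] is the y index into the last dimension. It computes the output once with an explicit loop and once with PyTorch advanced indexing, where broadcast arange tensors pair every (batch, point, channel) position with its own x and y.

```python
import torch

B, N, C, H, W = 2, 5, 2, 16, 16
value = torch.randn(B, N, C, H, W)
indices = torch.randint(0, 16, (B, N, 2))  # last dim assumed to be (x, y), each in [0, 15]

# Reference: out[b, n, c] = value[b, n, c, x, y] with (x, y) = indices[b, n]
out_loop = torch.empty(B, N, C)
for b in range(B):
    for n in range(N):
        x, y = indices[b, n].tolist()
        out_loop[b, n] = value[b, n, :, x, y]

# Vectorized: all five index tensors broadcast to a common shape (B, N, C)
b_idx = torch.arange(B).view(B, 1, 1)
n_idx = torch.arange(N).view(1, N, 1)
c_idx = torch.arange(C).view(1, 1, C)
x_idx = indices[..., 0].unsqueeze(-1)  # (B, N, 1)
y_idx = indices[..., 1].unsqueeze(-1)  # (B, N, 1)
out = value[b_idx, n_idx, c_idx, x_idx, y_idx]

print(out.shape)  # torch.Size([2, 5, 2])
print(torch.equal(out, out_loop))  # True
```

Because every index is a tensor, the result takes the broadcast shape of the index tensors, which is exactly (2, 5, 2) here.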
Edit:
I tried the suggestion with gather, but unfortunately it does not seem to work (I changed the dimensions, but it doesn't matter):
First I generate a coordinate grid:
import torch

y_t = torch.linspace(-1., 1., 16, device='cpu').reshape(16, 1).repeat(1, 16).unsqueeze(-1)  # (16, 16, 1), row coordinate
x_t = torch.linspace(-1., 1., 16, device='cpu').reshape(1, 16).repeat(16, 1).unsqueeze(-1)  # (16, 16, 1), column coordinate
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)  # (1, 2, 16, 16)
grid = grid.unsqueeze(1).repeat(1, 3, 1, 1, 1)  # (1, 3, 2, 16, 16)
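As a sanity check on the grid built above (a sketch, not part of the original post), the following rebuilds it and confirms its shape, that channel 0 varies only along the row dimension while channel 1 varies only along the column dimension, and that index 1 corresponds to -1 + 2/15, roughly -0.8667.

```python
import torch

y_t = torch.linspace(-1., 1., 16).reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16).reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(1, 3, 1, 1, 1)

lin = torch.linspace(-1., 1., 16)
print(grid.shape)  # torch.Size([1, 3, 2, 16, 16])
# Channel 0 holds the row coordinate, channel 1 the column coordinate
print(torch.equal(grid[0, 0, 0], lin.view(16, 1).expand(16, 16)))  # True
print(torch.equal(grid[0, 0, 1], lin.view(1, 16).expand(16, 16)))  # True
print(round(lin[1].item(), 4))  # -0.8667
```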
In the next step, I am creating some indices. In this case, I always take index 1:
indices = torch.ones([1, 3, 2], dtype=torch.int64)
Next, I am using your method:
indices = indices.unsqueeze(-1).unsqueeze(-1)
new_coords = torch.gather(grid, -1, indices).squeeze(-1).squeeze(-1)
Finally, I manually select index 1 for the x and y coordinates:
new_coords_manual = grid[:, :, :, 1, 1]
This outputs the following new coordinates:
new_coords
tensor([[[-1.0000, -0.8667],
[-1.0000, -0.8667],
[-1.0000, -0.8667]]])
new_coords_manual
tensor([[[-0.8667, -0.8667],
[-0.8667, -0.8667],
[-0.8667, -0.8667]]])
As you can see, it only works for one dimension. Do you have an idea how to fix that?
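With the index unsqueezed to shape (1, 3, 2, 1, 1), torch.gather(grid, -1, indices) computes out[b][n][c][i][j] = grid[b][n][c][i][indices[b][n][c][i][j]] with i = j = 0, so it moves to column 1 but always stays in row 0. A minimal sketch reusing the grid and indices from the post confirms that the result equals grid[:, :, :, 0, 1] rather than grid[:, :, :, 1, 1]:

```python
import torch

y_t = torch.linspace(-1., 1., 16).reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16).reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(1, 3, 1, 1, 1)  # (1, 3, 2, 16, 16)

indices = torch.ones([1, 3, 2], dtype=torch.int64).unsqueeze(-1).unsqueeze(-1)  # (1, 3, 2, 1, 1)
new_coords = torch.gather(grid, -1, indices).squeeze(-1).squeeze(-1)  # (1, 3, 2)

# gather only moved along the last dimension; the row index stayed at 0
print(torch.equal(new_coords, grid[:, :, :, 0, 1]))  # True
print(torch.equal(new_coords, grid[:, :, :, 1, 1]))  # False
```

Since channel 0 of the grid is the row coordinate, reading row 0 gives -1.0000, which is the first column of new_coords above.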
Upvotes: 9
Views: 13494
Reputation: 1036
I figured it out, thanks again @Ivan for your help! :)
The problem was that I unsqueezed on the last dimension, while I should have unsqueezed the middle dimensions so that the indices end up in the last dimension:
import torch

y_t = torch.linspace(-1., 1., 16, device='cpu').reshape(16, 1).repeat(1, 16).unsqueeze(-1)  # (16, 16, 1)
x_t = torch.linspace(-1., 1., 16, device='cpu').reshape(1, 16).repeat(16, 1).unsqueeze(-1)  # (16, 16, 1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)  # (1, 2, 16, 16)
grid = grid.unsqueeze(1).repeat(2, 3, 1, 1, 1)  # (2, 3, 2, 16, 16)
indices = torch.ones([2, 3, 2], dtype=torch.int64).unsqueeze(-2).unsqueeze(-2)  # (2, 3, 1, 1, 2)
new_coords = torch.gather(grid, 3, indices).squeeze(-2).squeeze(-2)  # gather along rows -> (2, 3, 2)
new_coords_manual = grid[:, :, :, 1, 1]  # (2, 3, 2)
Now new_coords equals new_coords_manual.
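In the gather call above, the index has size 1 in dimension 2 (the channel dimension), so gather only reads channel 0 of grid; the result matches the manual indexing because channel 0 of this particular grid depends only on the row index. For an arbitrary value tensor like the one in the question, one alternative (a sketch, not the method from the thread) is to flatten the last two dimensions and gather with the linear index x * W + y, expanded over all channels. Shapes follow the question: (2, 5, 2) indices and a (2, 5, 2, 16, 16) value, with indices[..., 0] assumed to be x.

```python
import torch

B, N, C, H, W = 2, 5, 2, 16, 16
value = torch.randn(B, N, C, H, W)
indices = torch.randint(0, 16, (B, N, 2))  # last dim assumed to be (x, y)

# Linear index into the flattened (H, W) plane: x * W + y
lin_idx = indices[..., 0] * W + indices[..., 1]   # (B, N)
lin_idx = lin_idx.unsqueeze(-1).unsqueeze(-1)     # (B, N, 1, 1)
lin_idx = lin_idx.expand(B, N, C, 1)              # same position for every channel

flat = value.flatten(start_dim=3)                 # (B, N, C, H * W)
out = torch.gather(flat, 3, lin_idx).squeeze(-1)  # (B, N, C)

# Spot check one entry against direct indexing
b, n = 1, 3
x, y = indices[b, n].tolist()
print(out.shape)  # torch.Size([2, 5, 2])
print(torch.equal(out[b, n], value[b, n, :, x, y]))  # True
```

Flattening works because flatten is row-major, so position (x, y) of an H x W plane lands at x * W + y.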
Upvotes: 3
Reputation: 40618
What you could do is flatten the first three axes together and apply torch.gather:
>>> grid.flatten(start_dim=0, end_dim=2).shape
torch.Size([6, 16, 16])
>>> torch.gather(grid.flatten(0, 2), dim=1, index=indices)
tensor([[[-0.8667, -0.8667],
[-0.8667, -0.8667],
[-0.8667, -0.8667]]])
As explained on the documentation page, this will perform:
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
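To see that rule in action, here is a small sketch with arbitrary toy tensors (not from the thread) that compares torch.gather with dim=1 against an explicit loop implementing out[i][j][k] = input[i][index[i][j][k]][k]. The output always has the shape of index.

```python
import torch

inp = torch.arange(24).view(2, 4, 3)     # (2, 4, 3)
index = torch.randint(0, 4, (2, 5, 3))   # values select positions along dim 1 (size 4)

out = torch.gather(inp, 1, index)        # same shape as index: (2, 5, 3)

expected = torch.empty_like(index)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        for k in range(index.size(2)):
            expected[i, j, k] = inp[i, index[i, j, k], k]

print(out.shape)  # torch.Size([2, 5, 3])
print(torch.equal(out, expected))  # True
```

Only dimension 1 is looked up through index; the i and k positions are copied straight through, which is why a single gather cannot select along two dimensions at once.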
Upvotes: 3