Reputation: 15
I have N groups of C-dimensional points. Each group contains M points, so the data is a tensor of shape (N, M, C). Let's call it features.
I computed the maximum value and its index along the M dimension to find the maximum point for each of the C channels (a max-pooling operation). This gives a max tensor of shape (N, 1, C) and an index tensor of shape (N, 1, C).
I have another tensor of shape (N, M, 3) storing the geometric coordinates of those N*M high-dimensional points. Now I want to use the index of the maximum point in each of the C channels to get the coordinates of all those maximum points.
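For reference, the pooling step described above might look like the following minimal sketch. It assumes the values and indices come from `Tensor.max` with `keepdim=True`; the random `features` tensor is only a stand-in for the real data.

```python
import torch

# Stand-in data: N groups, M points per group, C feature channels.
N, M, C = 2, 4, 6
features = torch.randn(N, M, C)

# Max over the M dimension; keepdim=True keeps that axis with size 1,
# so both outputs have shape (N, 1, C).
max_vals, max_idx = features.max(dim=1, keepdim=True)

print(max_vals.shape, max_idx.shape)
```

Here `max_idx[n, 0, c]` is the position (between 0 and M-1) of the point in group `n` with the largest value in channel `c`.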
For example, N=2, M=4, C=6.
The coordinate tensor, whose shape is (2, 4, 3):
[[[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
[8, 7, 6]]
[[11, 12, 13]
[14, 15, 16]
[17, 18, 19]
[18, 17, 16]]]
The indices tensor, whose shape is (2, 1, 6):
[[[0, 1, 2, 1, 2, 3]]
[[1, 2, 3, 2, 1, 0]]]
For example, the first element in indices is 0, so I want to grab [1, 2, 3] from the coordinate tensor. For the second element (1), I want [4, 5, 6]. For the third element of the second group (3), I want [18, 17, 16].
The result tensor will look like:
[[[1, 2, 3] # 0
[4, 5, 6] # 1
[7, 8, 9] # 2
[4, 5, 6] # 1
[7, 8, 9] # 2
[8, 7, 6]] # 3
[[14, 15, 16] # 1
[17, 18, 19] # 2
[18, 17, 16] # 3
[17, 18, 19] # 2
[14, 15, 16] # 1
[11, 12, 13]]]# 0
whose shape is (2, 6, 3).
I tried to use torch.gather but could not get it to work. I also wrote a naive algorithm that loops over all N groups, but it is slow, even with TorchScript's JIT. How can I write this efficiently in PyTorch?
Upvotes: 1
Views: 247
Reputation: 22184
You can use integer array indexing combined with broadcasting semantics to get your result.
import torch
x = torch.tensor([
[[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[8, 7, 6]],
[[11, 12, 13],
[14, 15, 16],
[17, 18, 19],
[18, 17, 16]],
])
i = torch.tensor([[[0, 1, 2, 1, 2, 3]],
[[1, 2, 3, 2, 1, 0]]])
# rows is shape [2, 1], cols is shape [2, 6]
rows = torch.arange(x.shape[0]).type_as(i).unsqueeze(1)
cols = i.squeeze(1)
# y is [2, 6, ...]
y = x[rows, cols]
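Since the question mentions torch.gather, here is a sketch of how the same lookup could be written with it. This is an alternative, not a required approach; the names `coords`, `idx`, `gather_idx` and `result` are made up for this example. The idea is to reshape the index from (N, 1, C) to (N, C, 3) so that every coordinate component of a selected row uses the same row index.

```python
import torch

# Coordinates, shape (N, M, 3) = (2, 4, 3).
coords = torch.tensor([
    [[1, 2, 3], [4, 5, 6], [7, 8, 9], [8, 7, 6]],
    [[11, 12, 13], [14, 15, 16], [17, 18, 19], [18, 17, 16]],
])
# Argmax indices, shape (N, 1, C) = (2, 1, 6).
idx = torch.tensor([[[0, 1, 2, 1, 2, 3]],
                    [[1, 2, 3, 2, 1, 0]]])

# (N, 1, C) -> (N, C, 1) -> (N, C, 3): repeat each row index across x, y, z.
gather_idx = idx.transpose(1, 2).expand(-1, -1, coords.shape[-1])

# result[n, c, k] = coords[n, idx[n, 0, c], k], shape (N, C, 3).
result = torch.gather(coords, 1, gather_idx)

# Cross-check against integer array indexing with broadcasting.
rows = torch.arange(coords.shape[0]).unsqueeze(1)   # shape (N, 1)
assert torch.equal(result, coords[rows, idx.squeeze(1)])
```

Both versions avoid a Python loop over the N groups. The indexing version is shorter, while the gather version makes the per-element lookup explicit.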
Upvotes: 1