David Xiong

Reputation: 15

PyTorch gather question (3D computer vision)

I have N groups of C-dimensional points, with M points in each group, so I have a tensor of shape (N, M, C). Let's call it features.

I computed the maximum values and their indices along the M dimension to find the maximum point for each of the C dimensions (a max pooling operation). This gives a max tensor of shape (N, 1, C) and an index tensor of shape (N, 1, C).
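A minimal sketch of that max pooling step, assuming random features and the sizes from the example below (N=2, M=4, C=6). torch.max with dim and keepdim=True returns both the values and the indices with the size-1 dimension kept:

```python
import torch

# Assumed sizes for illustration; features is random here.
N, M, C = 2, 4, 6
features = torch.randn(N, M, C)

# Max-pool over the point dimension (dim=1). keepdim=True keeps a size-1
# dimension, so both outputs have shape (N, 1, C).
max_vals, max_idx = torch.max(features, dim=1, keepdim=True)

print(max_vals.shape)  # torch.Size([2, 1, 6])
print(max_idx.shape)   # torch.Size([2, 1, 6])
```

max_idx[n, 0, c] is the index m of the point in group n that holds the largest value in channel c.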

I have another tensor of shape (N, M, 3) storing the geometric coordinates of those N*M high-dimensional points. I want to use the indices of the maximum points in each of the C dimensions to get the coordinates of all those maximum points.

For example, N=2, M=4, C=6.

The coordinate tensor, whose shape is (2, 4, 3):

[[[1, 2, 3],
  [4, 5, 6],
  [7, 8, 9],
  [8, 7, 6]],

 [[11, 12, 13],
  [14, 15, 16],
  [17, 18, 19],
  [18, 17, 16]]]

The indices tensor, whose shape is (2, 1, 6):

[[[0, 1, 2, 1, 2, 3]],
 [[1, 2, 3, 2, 1, 0]]]

For example, the first element of the indices tensor is 0, so I want to take [1, 2, 3] from the coordinate tensor. The second element is 1, so I want [4, 5, 6]. The third element of the second group is 3, so I want [18, 17, 16].

The result tensor will look like:

[[[1, 2, 3],      # 0
  [4, 5, 6],      # 1
  [7, 8, 9],      # 2
  [4, 5, 6],      # 1
  [7, 8, 9],      # 2
  [8, 7, 6]],     # 3

 [[14, 15, 16],   # 1
  [17, 18, 19],   # 2
  [18, 17, 16],   # 3
  [17, 18, 19],   # 2
  [14, 15, 16],   # 1
  [11, 12, 13]]]  # 0

whose shape is (2, 6, 3).

I tried to use torch.gather but could not get it to work. I wrote a naive loop over all N groups, but it is slow, even when compiled with TorchScript's JIT. How can I write this efficiently in PyTorch?
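The question does not show the naive version. A per-group loop of the kind described might look like the hypothetical gather_coords_loop below (the function name is illustrative, not from the post). It is still handy as a reference for checking any vectorized version:

```python
import torch

def gather_coords_loop(coords, idx):
    """Hypothetical per-group loop: coords is (N, M, 3), idx is (N, 1, C); returns (N, C, 3)."""
    out = []
    for n in range(coords.shape[0]):
        # idx[n, 0] is a 1-D tensor of C point indices for group n;
        # indexing coords[n] (shape (M, 3)) with it selects C rows -> (C, 3).
        out.append(coords[n][idx[n, 0]])
    return torch.stack(out)

coords = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [8, 7, 6]],
                       [[11, 12, 13], [14, 15, 16], [17, 18, 19], [18, 17, 16]]])
idx = torch.tensor([[[0, 1, 2, 1, 2, 3]],
                    [[1, 2, 3, 2, 1, 0]]])
result = gather_coords_loop(coords, idx)
print(result.shape)  # torch.Size([2, 6, 3])
```

The Python-level loop over N is what makes this slow for large N; each iteration launches separate indexing and the results must be stacked afterwards.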

Upvotes: 1

Views: 247

Answers (1)

jodag

Reputation: 22184

You can use integer array indexing combined with broadcasting semantics to get your result.
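To see what the broadcasting does, here is a small sketch of the two index tensors on their own, assuming N=2 and C=6 as in the question. A row index of shape (N, 1) and a column index of shape (N, C) broadcast to a common (N, C) grid, so the indexed result y[n, c] is x[n, cols[n, c]]:

```python
import torch

N, C = 2, 6
# Column (point) indices per group, shape (N, C): the index tensor with dim 1 squeezed.
cols = torch.tensor([[0, 1, 2, 1, 2, 3],
                     [1, 2, 3, 2, 1, 0]])
rows = torch.arange(N).unsqueeze(1)  # shape (N, 1)

# Broadcasting stretches rows across the C columns, so every
# (r[n, c], c_idx[n, c]) pair names one point in one group.
r, c_idx = torch.broadcast_tensors(rows, cols)
print(r.shape)     # torch.Size([2, 6])
print(r.tolist())  # [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]]
```

Because x has a trailing dimension of size 3 that is not indexed, each selected point keeps its full coordinate vector, giving a result of shape (N, C, 3).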

import torch

x = torch.tensor([
    [[1, 2, 3], 
     [4, 5, 6], 
     [7, 8, 9], 
     [8, 7, 6]],
    [[11, 12, 13],
     [14, 15, 16],
     [17, 18, 19],
     [18, 17, 16]],
])

i = torch.tensor([[[0, 1, 2, 1, 2, 3]],
                  [[1, 2, 3, 2, 1, 0]]])

# rows is shape [2, 1], cols is shape [2, 6]
rows = torch.arange(x.shape[0]).type_as(i).unsqueeze(1)
cols = i.squeeze(1)

# y has shape [2, 6, 3]
y = x[rows, cols]
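Since the question mentions torch.gather, here is a sketch of an equivalent gather call (an alternative, not part of the original answer). gather needs an index with the same number of dimensions as the input, so the (N, 1, C) indices are reshaped to (N, C, 1) and expanded across the 3 coordinate columns:

```python
import torch

x = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [8, 7, 6]],
                  [[11, 12, 13], [14, 15, 16], [17, 18, 19], [18, 17, 16]]])
i = torch.tensor([[[0, 1, 2, 1, 2, 3]],
                  [[1, 2, 3, 2, 1, 0]]])

# (N, 1, C) -> (N, C) -> (N, C, 1) -> expanded to (N, C, 3), so every
# coordinate column of a selected point uses the same point index.
idx = i.squeeze(1).unsqueeze(-1).expand(-1, -1, x.shape[-1])

# Along dim=1: y_gather[n, c, k] = x[n, idx[n, c, k], k]
y_gather = torch.gather(x, 1, idx)

# Same result as the advanced-indexing version.
rows = torch.arange(x.shape[0]).unsqueeze(1)
y_index = x[rows, i.squeeze(1)]
print(torch.equal(y_gather, y_index))  # True
```

Both versions are fully vectorized. With gather the index has to be expanded explicitly to the output shape, while advanced indexing broadcasts rows and cols for you.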

Upvotes: 1
