random_guy7531

Reputation: 23

Single shot multi-dimension indexing in torch - perhaps with index_select or gather?

I am performing a multi-index rearrangement of a matrix based on its correspondence data. Right now I am doing this with a pair of index_select calls, but this is very memory-inefficient (n^2 in terms of memory usage) and not ideal in terms of computation either. Is there some way to boil my operation down into a single .gather or .index_select call?

Essentially, given a source array of shape (I,J,K) and an array of indices of shape (I,J,2), I want to produce a result that satisfies:

result[i][j][:] = source[idx[i][j][0]] [idx[i][j][1]] [:]
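To pin down the semantics, here is a deliberately slow loop that implements the condition above element by element. It is a sketch for checking faster versions against, not something from the thread; the function name `gather_pairs_loop` is hypothetical, and it assumes every index pair is in range for the first two axes of `source`.

```python
import torch

def gather_pairs_loop(source: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: result[i, j] = source[idx[i, j, 0], idx[i, j, 1]]."""
    I, J = idx.shape[0], idx.shape[1]
    # Output has shape (I, J) plus any trailing axes of source (the K axis, if present)
    result = source.new_empty((I, J) + tuple(source.shape[2:]))
    for i in range(I):
        for j in range(J):
            a, b = idx[i, j].tolist()
            result[i, j] = source[a, b]
    return result

# The toy data from the question (2-D source, so there is no K axis)
source = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
indices = torch.tensor([[[2, 2], [3, 1], [0, 2]], [[0, 2], [0, 1], [0, 2]],
                        [[0, 2], [0, 1], [0, 2]], [[0, 2], [0, 1], [0, 2]]])
expected = gather_pairs_loop(source, indices)
# expected == tensor([[9, 11, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]])
```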

Here's a runnable toy example of how I'm doing things right now:

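import torch

source = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [10,11,12]])
indices = torch.tensor([[[2,2],[3,1],[0,2]],[[0,2],[0,1],[0,2]],[[0,2],[0,1],[0,2]],[[0,2],[0,1],[0,2]]])

# (I*J, 3): one copy of the selected row for every index pair
ax1 = torch.index_select(source,0,indices[:,:,0].flatten())
# (I*J, I*J): every selected row paired with every selected column
ax2 = torch.index_select(ax1, 1, indices[:,:,1].flatten())

# only the diagonal ax2[k, k] = source[row_k, col_k] is wanted; shape is a tuple, so index it with []
result = ax2.diagonal().reshape(indices.shape[0], indices.shape[1])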
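For comparison, the two index_select calls can be collapsed into a single one by flattening the first two axes of `source` and turning each (row, col) pair into a row-major linear index `row * n_cols + col`. This is a sketch that is not from the thread; the names `flat_source` and `linear` are hypothetical, and it assumes all index pairs are in range. Its memory use is proportional to the output rather than to (I*J)^2.

```python
import torch

source = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
indices = torch.tensor([[[2, 2], [3, 1], [0, 2]],
                        [[0, 2], [0, 1], [0, 2]],
                        [[0, 2], [0, 1], [0, 2]],
                        [[0, 2], [0, 1], [0, 2]]])

n_cols = source.shape[1]
# Merge axes 0 and 1 of source into one axis; a trailing K axis would be kept
flat_source = source.flatten(0, 1)
# Row-major position of source[row, col] inside the merged axis
linear = indices[..., 0] * n_cols + indices[..., 1]
# A single index_select, then restore the (I, J) layout of the index array
result = torch.index_select(flat_source, 0, linear.flatten())
result = result.reshape(*indices.shape[:2], *source.shape[2:])
```

The same code works unchanged for a 3-D (I,J,K) source, since `flatten(0, 1)` and the final reshape both carry the trailing axis along.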

This approach works for me only because my images are rather small, so everything still fits in memory despite the wasted off-diagonal entries. Regardless, I am producing a huge amount of data that isn't needed: the intermediate ax2 holds (I·J)^2 entries per K slice, so it grows quadratically with the number of index pairs and linearly with K. Perhaps I'm just missing something obvious in the documentation, but this seems like a problem someone else must have run into before, so I hope someone can help me out!
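To make the size argument concrete, the sketch below runs the two-index_select approach on a small random 3-D source and compares the intermediate against the size of the wanted result. The sizes I, J, K = 4, 3, 5 are arbitrary illustrative values. Note also that for a 3-D input `ax2.diagonal()` appends the diagonal as the last axis (shape (K, I*J)), so the original `reshape(I, J)` would need a transpose first.

```python
import torch

torch.manual_seed(0)
I, J, K = 4, 3, 5  # arbitrary illustrative sizes
source = torch.randn(I, J, K)
indices = torch.stack([torch.randint(0, I, (I, J)),
                       torch.randint(0, J, (I, J))], dim=-1)

ax1 = torch.index_select(source, 0, indices[:, :, 0].flatten())  # (I*J, J, K)
ax2 = torch.index_select(ax1, 1, indices[:, :, 1].flatten())     # (I*J, I*J, K)

intermediate_elems = ax2.numel()  # (I*J)**2 * K = 720 here
needed_elems = I * J * K          # size of the actual result = 60 here

# diagonal() over dims 0 and 1 moves the diagonal to the end: shape (K, I*J)
diag = ax2.diagonal()
via_diag = diag.t().reshape(I, J, K)
```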

Upvotes: 2

Views: 882

Answers (1)

jodag

Reputation: 22284

You already have your indices in a nice form for integer array indexing, so you can simply do

result = source[indices[..., 0], indices[..., 1], ...]
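A quick way to see the answer in action on the 3-D case from the question is the sketch below, which builds a random (I,J,K) source and checks the one-liner against the element-wise definition. The sizes I, J, K = 4, 3, 5 are arbitrary illustrative values. Because the two index tensors broadcast to shape (I, J) and the trailing `...` keeps the K axis, the result comes out as (I, J, K) without building the (I·J)^2 intermediate.

```python
import torch

torch.manual_seed(0)
I, J, K = 4, 3, 5  # arbitrary illustrative sizes
source = torch.arange(I * J * K).reshape(I, J, K)
# indices[i, j] = (row, col) into the first two axes of source
indices = torch.stack([torch.randint(0, I, (I, J)),
                       torch.randint(0, J, (I, J))], dim=-1)

# Integer array indexing: both index tensors have shape (I, J); ... keeps K
result = source[indices[..., 0], indices[..., 1], ...]

# Check against result[i][j][:] = source[idx[i][j][0]][idx[i][j][1]][:]
for i in range(I):
    for j in range(J):
        a, b = indices[i, j].tolist()
        assert torch.equal(result[i, j], source[a, b])
```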

Upvotes: 1
