Reputation: 23
I am performing a multi-index re-arrangement of a matrix based upon its correspondence data. Right now I am doing this with a pair of index_select calls, but this is very memory inefficient (O(n^2) in memory), and not exactly ideal in terms of compute either. Is there some way I can boil my operation down into a single .gather or .index_select call?
What I essentially want to do is when given a source array of shape (I,J,K), and an array of indices of shape (I,J,2), produce a result which meets the condition:
result[i][j][:] = source[idx[i][j][0]][idx[i][j][1]][:]
Here's a runnable toy example of how I'm doing things right now:
source = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [10,11,12]])
indices = torch.tensor([[[2,2],[3,1],[0,2]],[[0,2],[0,1],[0,2]],[[0,2],[0,1],[0,2]],[[0,2],[0,1],[0,2]]])
ax1 = torch.index_select(source,0,indices[:,:,0].flatten())
ax2 = torch.index_select(ax1, 1, indices[:,:,1].flatten())
result = ax2.diagonal().reshape(indices.shape[0], indices.shape[1])
This approach works for me only because my images are rather small, so the (I*J) x (I*J) intermediate from the second index_select fits into memory despite the diagonalization trick. Regardless, I am producing a huge amount of data that doesn't need to exist, and since that intermediate also carries the K dimension, the problem scales up directly with K. Perhaps I'm just missing something obvious in the documentation, but I feel like this is a problem somebody else has run into before who can help me out!
Upvotes: 2
Views: 882
Reputation: 22284
You already have your indices in the right form for integer array indexing, so you can simply do
result = source[indices[..., 0], indices[..., 1]]
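To see this on the question's own toy data (and, as a sketch, the single-index_select variant the question asked about, built by raveling each (row, col) pair into one flat index -- an assumption on my part, not something the question requires):

```python
import torch

# Toy data from the question: source is (4, 3). With a trailing K
# dimension, the same integer-array indexing applies unchanged.
source = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
indices = torch.tensor([
    [[2, 2], [3, 1], [0, 2]],
    [[0, 2], [0, 1], [0, 2]],
    [[0, 2], [0, 1], [0, 2]],
    [[0, 2], [0, 1], [0, 2]],
])

# Integer-array indexing: result[i, j] = source[indices[i,j,0], indices[i,j,1]]
result = source[indices[..., 0], indices[..., 1]]
print(result)
# tensor([[ 9, 11,  3],
#         [ 3,  2,  3],
#         [ 3,  2,  3],
#         [ 3,  2,  3]])

# Equivalent single index_select, if that form is preferred: ravel each
# (row, col) pair into one flat index over a flattened view of source.
# (For an (I, J, K) source, index source.view(-1, K) along dim 0 instead.)
flat = indices[..., 0] * source.shape[1] + indices[..., 1]
result2 = torch.index_select(source.flatten(), 0, flat.flatten()).reshape(flat.shape)
assert torch.equal(result, result2)
```

Both forms allocate only the output-sized result, avoiding the (I*J) x (I*J) intermediate entirely.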
Upvotes: 1