Reputation: 13
I am trying to vectorize the following:
n = torch.zeros_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        for k in range(x.shape[2]):
            n[i, j, k] = p[i, x[i, j, k], j, k]
I tried doing something like
n = p[:, x, ...]
but I just get an out-of-memory error, which isn't very helpful. I think the problem is that instead of picking out the value of x at the correct index, it is indexing with the entirety of x, but I am not sure how to fix that if that is indeed the problem.
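To illustrate what I think is happening, here is a small sketch with made-up shapes; the indexed result is far larger than x:

import torch

# Made-up small shapes just to see what p[:, x, ...] produces.
A, B, C, D = 4, 5, 6, 7
p = torch.randn(A, B, C, D)
x = torch.randint(0, B, (A, C, D))

# Using the whole tensor x as an index inserts all of x's dimensions,
# so the result has shape (A, A, C, D, C, D) instead of the (A, C, D) I want.
print(p[:, x, ...].shape)  # torch.Size([4, 4, 6, 7, 6, 7])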
Upvotes: 1
Views: 92
Reputation: 114230
This looks like a perfect use-case for broadcasted fancy indexing. np.ogrid is a valuable tool here, or you can manually reshape your ranges:
i, j, k = np.ogrid[:x.shape[0], :x.shape[1], :x.shape[2]]
n = p[i, x, j, k]
This black magic works because indexing into np.ogrid returns three arrays whose shapes broadcast together to the same shape as x. Therefore the final extraction from p will have that shape, and the indexing is trivial after that. Another way to write it is:
i = np.arange(x.shape[0]).reshape(-1, 1, 1)
j = np.arange(x.shape[1]).reshape(1, -1, 1)
k = np.arange(x.shape[2]).reshape(1, 1, -1)
n = p[i, x, j, k]
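Since the question is actually about torch tensors, the same broadcasting trick carries over with torch.arange. A quick sketch with made-up shapes, checked against the original triple loop:

import torch

# Made-up shapes purely for demonstration.
A, B, C, D = 3, 5, 4, 6
p = torch.randn(A, B, C, D)
x = torch.randint(0, B, (A, C, D))

# Torch equivalent of the reshaped-arange indices above.
i = torch.arange(A).reshape(-1, 1, 1)
j = torch.arange(C).reshape(1, -1, 1)
k = torch.arange(D).reshape(1, 1, -1)
n = p[i, x, j, k]          # indices broadcast to shape (A, C, D)

# Verify against the original loop formulation.
expected = torch.zeros(A, C, D, dtype=p.dtype)
for ii in range(A):
    for jj in range(C):
        for kk in range(D):
            expected[ii, jj, kk] = p[ii, x[ii, jj, kk], jj, kk]

print(torch.equal(n, expected))  # True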
Upvotes: 2