nicememe420

Reputation: 13

How can I vectorize these nested loops in Python?

I am trying to vectorize the following:

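import torch

# x: integer index tensor of shape (X, Y, Z); p: tensor of shape (X, C, Y, Z)
# dtype=p.dtype so values copied from p are not truncated to x's integer dtype
n = torch.zeros_like(x, dtype=p.dtype)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        for k in range(x.shape[2]):
            n[i, j, k] = p[i, x[i, j, k], j, k]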

I tried doing something like

n = p[:, x, ...]

but I just get an out-of-memory error, which isn't very helpful. I think the problem is that instead of using the value of x at the matching position, it indexes with the entirety of x, but I am not sure how to fix that if that is the problem.
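That guess is essentially right. The sketch below uses small made-up sizes and random data, with NumPy as a stand-in because it follows the same advanced-indexing rules as PyTorch, to show what p[:, x, ...] actually builds:

```python
import numpy as np

# Hypothetical small sizes: p has shape (X, C, Y, Z); x holds indices into axis 1 of p
X, C, Y, Z = 2, 3, 4, 5
rng = np.random.default_rng(0)
p = rng.standard_normal((X, C, Y, Z))
x = rng.integers(0, C, size=(X, Y, Z))

# The attempted expression: x is a single advanced index on axis 1,
# so every element of x selects a whole (Y, Z) slice for every i
bad = p[:, x, ...]
print(bad.shape)  # (2, 2, 4, 5, 4, 5)

# The loop only needs one value per element of x
print(bad.size, x.size)  # 1600 40
```

The result has shape (X,) + x.shape + (Y, Z), i.e. (X * Y * Z) ** 2 elements instead of X * Y * Z, which is why realistic sizes exhaust memory.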

Upvotes: 1

Views: 92

Answers (1)

Mad Physicist

Reputation: 114230

This looks like a perfect use case for broadcast fancy indices. np.ogrid is a valuable tool here; alternatively, you can reshape your ranges manually:

import numpy as np

# p: shape (X, C, Y, Z); x: integer index array of shape (X, Y, Z)
i, j, k = np.ogrid[:x.shape[0], :x.shape[1], :x.shape[2]]
n = p[i, x, j, k]

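This black magic works because indexing np.ogrid returns three open-grid arrays of shapes (x.shape[0], 1, 1), (1, x.shape[1], 1) and (1, 1, x.shape[2]), which broadcast together with x to exactly x's shape. Advanced indexing then picks one element of p per position in that broadcast shape, so n[i, j, k] is p[i, x[i, j, k], j, k], which is what the loop computes, and n has the same shape as x. Another way to write it is: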
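To make the broadcasting concrete, here is a small NumPy check on hypothetical sizes and random data that prints the ogrid shapes and compares the vectorized result against the original triple loop:

```python
import numpy as np

# Hypothetical example data: p has shape (X, C, Y, Z), x holds indices into axis 1
X, C, Y, Z = 2, 3, 4, 5
rng = np.random.default_rng(0)
p = rng.standard_normal((X, C, Y, Z))
x = rng.integers(0, C, size=(X, Y, Z))

# Open grids: one non-singleton axis each, so they broadcast against x
i, j, k = np.ogrid[:X, :Y, :Z]
print(i.shape, j.shape, k.shape)  # (2, 1, 1) (1, 4, 1) (1, 1, 5)

# i, x, j, k broadcast to x.shape, so the result has that shape too
n = p[i, x, j, k]

# Reference: the loop from the question
ref = np.zeros(x.shape, dtype=p.dtype)
for a in range(X):
    for b in range(Y):
        for c in range(Z):
            ref[a, b, c] = p[a, x[a, b, c], b, c]

print(n.shape, np.array_equal(n, ref))  # (2, 4, 5) True
```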

i = np.arange(x.shape[0]).reshape(-1, 1, 1)
j = np.arange(x.shape[1]).reshape(1, -1, 1)
k = np.arange(x.shape[2]).reshape(1, 1, -1)
n = p[i, x, j, k]
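Since the question uses PyTorch, here is a sketch of the same reshaped-range approach built with torch.arange and view, checked against the question's loop on made-up data. It assumes x is an integer (int64) tensor of indices, which is what torch.randint produces by default:

```python
import torch

# Hypothetical example data: p has shape (X, C, Y, Z), x holds int64 indices into dim 1
X, C, Y, Z = 2, 3, 4, 5
torch.manual_seed(0)
p = torch.randn(X, C, Y, Z)
x = torch.randint(0, C, (X, Y, Z))

# torch counterparts of the reshaped np.arange ranges
i = torch.arange(X).view(-1, 1, 1)
j = torch.arange(Y).view(1, -1, 1)
k = torch.arange(Z).view(1, 1, -1)
n = p[i, x, j, k]  # broadcasts to x.shape == (X, Y, Z)

# Reference: the loop from the question
ref = torch.zeros_like(x, dtype=p.dtype)
for a in range(X):
    for b in range(Y):
        for c in range(Z):
            ref[a, b, c] = p[a, x[a, b, c], b, c]

print(n.shape, torch.equal(n, ref))  # torch.Size([2, 4, 5]) True
```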

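Not part of the answer above, but a common alternative: because the data-dependent index only varies along dim 1, torch.gather can express the same selection without building explicit index ranges. A sketch on hypothetical data, compared with the broadcast-indexing result:

```python
import torch

# Hypothetical example data: p has shape (X, C, Y, Z), x holds int64 indices into dim 1
X, C, Y, Z = 2, 3, 4, 5
torch.manual_seed(0)
p = torch.randn(X, C, Y, Z)
x = torch.randint(0, C, (X, Y, Z))

# gather needs an index with as many dims as p, so insert a size-1 dim 1:
# out[a, 0, b, c] = p[a, x[a, b, c], b, c]; squeezing dim 1 gives shape (X, Y, Z)
g = p.gather(1, x.unsqueeze(1)).squeeze(1)

# Broadcast-indexing version for comparison
i = torch.arange(X).view(-1, 1, 1)
j = torch.arange(Y).view(1, -1, 1)
k = torch.arange(Z).view(1, 1, -1)
n = p[i, x, j, k]

print(g.shape, torch.equal(g, n))  # torch.Size([2, 4, 5]) True
```

gather only reads one element per output position, so like the broadcast version it allocates an output the size of x rather than the blown-up tensor from p[:, x, ...].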
Upvotes: 2
