abhinavkulkarni

Reputation: 2409

np.where IndexError exception

I have a very simple piece of code, as follows:

import numpy as np
num_classes = 12
im_pred = np.random.randint(0, num_classes, (224, 244))
img = np.zeros((224, 224, 3))
print(im_pred.shape)
#(224, 244)
print(img.shape)
#(224, 224, 3)
for i in range(num_classes):
    img[np.where(im_pred==i), :] = [225, 0, 0]

Traceback (most recent call last):
File "", line 2, in <module>
IndexError: index 227 is out of bounds for axis 0 with size 224

x, y = np.where(im_pred==i)
print(np.max(x), np.max(y))
#223 243

Why am I getting an IndexError? As I understand np.where, the returned indices should all be less than 224.

Let me know. I'm starting to wonder whether my NumPy installation is buggy.

Thanks.

Upvotes: 0

Views: 572

Answers (2)

bnaecker

Reputation: 6450

The problem is that you've given img and im_pred different shapes:

im_pred.shape == (224, 244)

while

img.shape == (224, 224, 3)

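Their second axes have different sizes (244 vs. 224), so np.where(im_pred==i) can return column indices up to 243, while img only has columns 0 through 223.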
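A quick way to see this is to unpack what np.where returns and compare each index array against img's shape. This is a minimal sketch using the shapes from the question; the seeded generator (rng, default_rng(0)) is an addition of mine so the run is reproducible, and class 0 is picked arbitrarily.

```python
import numpy as np

num_classes = 12
rng = np.random.default_rng(0)  # seeded for reproducibility (not in the question)
im_pred = rng.integers(0, num_classes, size=(224, 244))  # same shape as the question
img = np.zeros((224, 224, 3))

# With a single argument, np.where returns a tuple with one index array per axis.
rows, cols = np.where(im_pred == 0)

# Each array is bounded by the matching axis of im_pred, not of img.
max_row, max_col = rows.max(), cols.max()
print("largest row index:", max_row, "| img rows:", img.shape[0])
print("largest column index:", max_col, "| img columns:", img.shape[1])

# Any column index >= 224 has no matching pixel in img.
print("out-of-bounds columns present:", bool((cols >= img.shape[1]).any()))
```

The row indices fit inside img, but the column indices run past its 224 columns, which is what triggers the IndexError.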

Once you fix that, there's a simple improvement to make: you don't need np.where here at all. Just use boolean (logical) indexing directly:

for i in range(num_classes):
    img[im_pred == i, 0] = 255

Note that I've also dropped the two zeros, since np.zeros already fills the array with zeros when it is created.
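Putting the shape fix and the boolean indexing together, a complete version might look like the sketch below. The names height and width are mine, and the second loop (img2) is an optional variant showing that a mask on its own selects whole pixels, so a full RGB triplet can be assigned instead of a single channel.

```python
import numpy as np

num_classes = 12
height, width = 224, 224  # im_pred and img now share their first two axes
im_pred = np.random.randint(0, num_classes, (height, width))
img = np.zeros((height, width, 3))

for i in range(num_classes):
    # im_pred == i is a (224, 224) boolean mask; pairing it with channel
    # index 0 selects the red channel of every pixel predicted as class i.
    img[im_pred == i, 0] = 255

# Variant: the mask alone selects whole pixels as an (n, 3) block,
# and the triplet [255, 0, 0] broadcasts across those n pixels.
img2 = np.zeros((height, width, 3))
for i in range(num_classes):
    img2[im_pred == i] = [255, 0, 0]
```

Because every pixel belongs to one of the 12 classes, both loops end up painting the whole image red; in practice you would presumably pick a different color per class.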

Upvotes: 1

jfish003

Reputation: 1332

No, NumPy is not buggy. Look at how you defined im_pred: you are drawing random integers between 0 and 11 into an array of shape 224 by 244. The error is thrown because that second dimension, of size 244, is too large for img, which is only 224 by 224 by 3. I think you meant both arrays to have the same first and second dimensions, something like:

img = np.zeros((224, 244, 3))
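If you want to keep np.where, one sketch is to unpack its result so the row indices go to axis 0 and the column indices to axis 1. As far as I can tell, writing img[np.where(im_pred==i), :] puts the whole (rows, cols) tuple into the first index slot, where NumPy converts it into a single integer array used on axis 0 (consistent with the traceback complaining about axis 0). So even with matching shapes, that form would select whole rows rather than individual pixels. Unpacking avoids this; img[np.where(im_pred == i)] = [255, 0, 0] is an equivalent shorthand.

```python
import numpy as np

num_classes = 12
im_pred = np.random.randint(0, num_classes, (224, 244))
img = np.zeros((224, 244, 3))  # first two axes now match im_pred

for i in range(num_classes):
    # np.where returns (row_indices, col_indices); unpack so each one
    # indexes its own axis of img, then assign an RGB triplet per pixel.
    rows, cols = np.where(im_pred == i)
    img[rows, cols] = [255, 0, 0]
```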

Upvotes: 1
