CYD

Reputation: 33

Delete whole 3d array if contains any nan

I have a set of images stored in a 3D ndarray. I want to delete an entire image if any of its pixel values is NaN. Consider the following ndarray:

import numpy as np

a = np.arange(18).reshape(3, 2, 3)
a = 1.0 * a  # convert to float so the array can hold NaN
a[0][0][1] = np.nan
a[1][0][0] = np.nan
a
[[[ 0. nan  2.]
  [ 3.  4.  5.]]

 [[nan  7.  8.]
  [ 9. 10. 11.]]

 [[12. 13. 14.]
  [15. 16. 17.]]]

What I need is a function that, given this ndarray, returns True, True, False, so that I can pass the result to np.delete.

I have tried the following, which works:

np.delete(a, [np.isnan(image.flatten()).any() for image in a], axis=0)
array([[[12., 13., 14.],
        [15., 16., 17.]]])

However, I find it hard to believe that NumPy has no more efficient way to do this, and since I have a lot of images I would like to optimise it as much as possible.

Upvotes: 3

Views: 308

Answers (1)

Sadman Sakib

Reputation: 595

As Michael Szczesny already answered, a more Pythonic way would be:

filtered_images = a[~np.isnan(a).any(axis=(2, 1))]

If that piece of code is hard to understand, you can instead check each image in a for loop and keep only the ones without NaN:

filtered_images = []
for image in a:
    # Keep the image only if none of its pixels is NaN.
    if not np.isnan(image).any():
        filtered_images.append(image)

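Both approaches keep the same images. Note that the loop produces a Python list of 2D arrays rather than a 3D ndarray, so wrap it in np.array(filtered_images) if you need an array.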
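As a quick check, here is a minimal sketch that rebuilds the example array from the question, computes the per-image mask the question asks for, and confirms that the vectorized version and the loop agree. The names has_nan and looped are just illustrative, and np.stack is one possible way to turn the loop's list back into a 3D array (it raises an error if the list is empty, so that case would need separate handling).

```python
import numpy as np

# Same example array as in the question.
a = np.arange(18, dtype=float).reshape(3, 2, 3)
a[0, 0, 1] = np.nan
a[1, 0, 0] = np.nan

# One boolean per image: True if that image contains at least one NaN.
has_nan = np.isnan(a).any(axis=(1, 2))
print(has_nan.tolist())  # [True, True, False]

# Boolean indexing keeps only the images without NaN.
filtered_images = a[~has_nan]
print(filtered_images.shape)  # (1, 2, 3)

# Loop version, stacked back into a 3D array for comparison.
looped = np.stack([image for image in a if not np.isnan(image).any()])
print(np.array_equal(filtered_images, looped))  # True
```

The mask has_nan is the True, True, False result the question asks for; indexing with its negation avoids the Python-level loop entirely, which is usually what matters when there are many images.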

Upvotes: 1
