Python/Numpy - Vectorized implementation of this for loop?

This is a slow implementation of a cloud mask based on interpolation across the temporal axis of a satellite image. The image array has shape (n_samples x n_months x width x height x channels). The channels are not just RGB; they also include bands from the invisible spectrum such as SWIR and NIR. One of the channels (or bands, in satellite-imagery terms) is a cloud mask in which 0 means "no cloud" and 1024 or 2048 means "cloud" at that pixel.

I'm using this pixel-wise cloud mask channel to replace the values in all remaining channels by interpolating between the previous and next months. This implementation is very slow, and I'm having a hard time coming up with a vectorized implementation.

  1. Is it possible to vectorize this implementation? If so, what would it look like?
  2. Any suggestions on how to work out a vectorized implementation of complex array operations? In other words, how do I learn the art of vectorization?

I'm a novice, so please excuse my ignorance.

import numpy as np

n_samples = 1055
n_months = 12
width = 40
height = 40
channels = 13 # channel 13 is the cloud mask, based on which the first 12 channels' pixels are interpolated

# Fill NaN values in a 1-D array by linear interpolation over the index
def fill_nan(y):
    nans = np.isnan(y)
    x = lambda z: z.nonzero()[0]  # positions where the boolean array z is True
    y[nans] = np.interp(x(nans), x(~nans), y[~nans])
    return y

# First loop: set cloudy pixels to NaN
for sample in range(n_samples):
    for temp in range(n_months):
        for w in range(width):
            for h in range(height):
                if Xtest[sample, temp, w, h, 13] > 0:
                    Xtest[sample, temp, w, h, :12] = np.nan

# Second loop: fill NaNs with values interpolated along the month axis
for sample in range(n_samples):
    for w in range(width):
        for h in range(height):
            for ch in range(12):
                Xtest[sample, :, w, h, ch] = fill_nan(Xtest[sample, :, w, h, ch])

Upvotes: 0

Views: 93

Answers (1)

Ananda

Reputation: 3272

For the first loop, a single boolean-mask assignment can replace all four nested loops:

import numpy as np

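Xtest = np.random.rand(10, 3, 2, 4, 14)
# Give the mask channel realistic values (0 = clear, 1024/2048 = cloud);
# with plain rand() nearly every pixel would count as cloudy.
Xtest[..., 13] = np.random.choice([0, 1024, 2048], size=Xtest.shape[:-1])
Xtest_v = Xtest.copy()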

for sample in range(10):
    for temp in range(3):
        for w in range(2):
            for h in range(4):
                if Xtest[sample,temp,w,h,13] > 0:
                    Xtest[sample,temp,w,h,:12] = np.nan

# Vectorized version: the boolean mask selects pixels over the first four axes
Xtest_v[..., :12][Xtest_v[..., 13] > 0] = np.nan

print(np.nansum(Xtest))
print(np.nansum(Xtest_v))

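You can verify that the two arrays match by printing their sums with NaNs ignored. Equal sums are only a quick sanity check; np.array_equal(Xtest, Xtest_v, equal_nan=True) compares the arrays element by element.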
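
For the second loop, one possible approach is to find, for every position, the nearest valid sample before and after it along the month axis and blend the two. The following is a sketch rather than a tested drop-in: the helper name `fill_nan_along_axis` is made up here, and it assumes the same edge behaviour as np.interp (hold the nearest valid value at the ends). A series that is entirely NaN is left as NaN.

```python
import numpy as np

def fill_nan_along_axis(a, axis=1):
    """Linearly interpolate NaNs along `axis`, holding edge values like np.interp."""
    a = np.moveaxis(a, axis, -1)            # put the time axis last
    shape = a.shape
    y = a.reshape(-1, shape[-1])            # (n_series, n_steps)
    n_steps = y.shape[1]
    idx = np.arange(n_steps)
    valid = ~np.isnan(y)

    # Index of the nearest valid sample at or before each position (-1 if none)
    prev = np.maximum.accumulate(np.where(valid, idx, -1), axis=1)
    # Index of the nearest valid sample at or after each position (n_steps if none)
    nxt = np.minimum.accumulate(
        np.where(valid, idx, n_steps)[:, ::-1], axis=1)[:, ::-1]

    # At the edges only one neighbour exists, so use it for both sides
    prev_c = np.clip(np.where(prev >= 0, prev, nxt), 0, n_steps - 1)
    next_c = np.clip(np.where(nxt < n_steps, nxt, prev), 0, n_steps - 1)

    rows = np.arange(y.shape[0])[:, None]
    y0, y1 = y[rows, prev_c], y[rows, next_c]
    span = next_c - prev_c
    # Fractional position between the two neighbours (0 where they coincide)
    t = np.where(span > 0, (idx - prev_c) / np.where(span > 0, span, 1), 0.0)

    out = np.where(valid, y, y0 + t * (y1 - y0))
    return np.moveaxis(out.reshape(shape), -1, axis)

# Tiny check on one series
demo = np.array([[np.nan, 1.0, np.nan, 3.0, np.nan]])
print(fill_nan_along_axis(demo, axis=1))  # [[1. 1. 2. 3. 3.]]
```

With the question's array this would be called as Xtest[..., :12] = fill_nan_along_axis(Xtest[..., :12], axis=1). It builds several temporaries the size of the input, so for the full array it may be worth processing a few hundred samples at a time.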

Upvotes: 2
