shaunakde

Reputation: 3108

Advice on vectorizing block-wise operations in Numpy

I am trying to implement a series of statistical operations, and I need help vectorizing my code.

The idea is to extract NxN patches from two images, then compute a distance metric between each pair of patches.

To do this, I first construct the patches with the following loops:

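import numpy as np
from time import time

# imga, imgb, N and patch1 are defined earlier in the script
t0 = time()
params = []
for i in range(patch1.shape[0]):
    for j in range(patch1.shape[1]):
        # slicing followed by flatten() already returns a copy, so np.copy is not needed
        window1 = imga[i:i+N, j:j+N].flatten()
        window2 = imgb[i:i+N, j:j+N].flatten()
        params.append((window1, window2))
print(f"We took {time() - t0:2.2f} seconds to prepare {len(params)/1e6} million patches.")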
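One way to build every stride-1 NxN window without the Python loop is NumPy's `sliding_window_view` (available since NumPy 1.20). This is a minimal sketch under stated assumptions: the small random images and `N = 4` are hypothetical stand-ins for the real `imga`, `imgb` and patch size, and windows are only taken where they fit entirely inside the image.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

rng = np.random.default_rng(0)
# Hypothetical stand-ins for the real images and patch size
imga = rng.random((50, 60))
imgb = rng.random((50, 60))
N = 4

# Views of every NxN window, shape (H-N+1, W-N+1, N, N); no data is copied yet
wa = sliding_window_view(imga, (N, N))
wb = sliding_window_view(imgb, (N, N))

# Flatten to (num_patches, 2, N*N); this reshape does copy the data
params = np.stack([wa.reshape(-1, N * N), wb.reshape(-1, N * N)], axis=1)
```

The result has the same layout as the list built by the loop (one entry per window, two flattened patches per entry), but as a dense array it costs roughly 2·N²·8 bytes per window position for float64 data, so for a 10K × 10K image it is usually better to keep the views and process them in chunks.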

This takes about 10 seconds to complete, and I'm not overly concerned with the pre-processing time. The steps that follow are the ones I want to optimize.

Next, to speed up processing, I used multiprocessing.Pool to compute the actual results. The function that performs the computation is:

import numpy as np
from numba import njit

@njit
def cauchy_schwartz(imga, imgb):
    # normalized 10-bin histograms of the two patches
    p, _ = np.histogram(imga, bins=10)
    p = p / np.sum(p)
    q, _ = np.histogram(imgb, bins=10)
    q = q / np.sum(q)

    # Cauchy-Schwarz divergence: -log(<p, q> / (||p|| * ||q||))
    n_d = np.sum(p * q)
    d_d = np.sqrt(np.sum(p**2) * np.sum(q**2))
    return -np.log10(n_d / d_d)
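For reference, the textbook Cauchy-Schwarz divergence between two normalized histograms p and q, which `cauchy_schwartz` appears to target, is written below. The base of the logarithm only rescales the result; by the Cauchy-Schwarz inequality the ratio is at most 1 for non-negative histograms, so the divergence is non-negative and equals 0 when p = q.

```latex
D_{CS}(p, q) = -\log_{10} \frac{\sum_k p_k q_k}{\sqrt{\left(\sum_k p_k^2\right)\left(\sum_k q_k^2\right)}}
```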

I use this structure to process all the patches:

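from multiprocessing import Pool
import tqdm

def f(param):
    # unpack the (window1, window2) tuple into the two arguments
    return cauchy_schwartz(*param)

with Pool(4) as p:
    r = list(tqdm.tqdm(p.imap(f, params), total=len(params)))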

I am sure there must be a more elegant way to do this. If I pass the whole 10K × 10K px images to the cauchy_schwartz function, it processes everything in under a second, but with my approach it takes a long time, even on 4 cores.

My mental model is how blockproc works in MATLAB, and I ended up writing this code in that pattern. I would appreciate any advice on improving the performance of this code.
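A sketch of a whole-image, fully vectorized alternative, under one assumption that differs from the code above: both images share fixed bin edges (here 10 equal bins on [0, 1]) instead of the per-patch min/max range that `np.histogram(x, bins=10)` picks. With shared edges, each pixel's bin index can be computed once for the whole image, and the count of each bin in every NxN window follows from an integral image (summed-area table), so the per-pixel cost no longer depends on N. The helper names `window_counts` and `cs_divergence_map` are hypothetical.

```python
import numpy as np

def window_counts(mask, N):
    """Number of True pixels in every NxN window (stride 1), via an integral image."""
    H, W = mask.shape
    S = np.zeros((H + 1, W + 1), dtype=np.int64)
    S[1:, 1:] = mask.cumsum(axis=0).cumsum(axis=1)
    # Sum over rows i..i+N-1 and columns j..j+N-1 for every top-left corner (i, j)
    return S[N:, N:] - S[:-N, N:] - S[N:, :-N] + S[:-N, :-N]

def cs_divergence_map(imga, imgb, N, edges):
    """Cauchy-Schwarz divergence between the histograms of co-located NxN windows.

    Assumes both images share the bin edges `edges`; values outside
    [edges[0], edges[-1]] are clamped into the first or last bin.
    """
    nbins = len(edges) - 1
    # Bin index (0..nbins-1) of every pixel, computed once per image
    ia = np.digitize(imga, edges[1:-1])
    ib = np.digitize(imgb, edges[1:-1])

    out_shape = (imga.shape[0] - N + 1, imga.shape[1] - N + 1)
    num = np.zeros(out_shape)  # sum_k a_k * b_k
    sa = np.zeros(out_shape)   # sum_k a_k**2
    sb = np.zeros(out_shape)   # sum_k b_k**2
    for k in range(nbins):
        ca = window_counts(ia == k, N)
        cb = window_counts(ib == k, N)
        num += ca * cb
        sa += ca * ca
        sb += cb * cb
    # Dividing the counts by N*N to get p and q cancels in this ratio
    return -np.log10(num / np.sqrt(sa * sb))

rng = np.random.default_rng(0)
imga = rng.random((40, 50))   # hypothetical stand-in images
imgb = rng.random((40, 50))
edges = np.linspace(0.0, 1.0, 11)
dmap = cs_divergence_map(imga, imgb, N=8, edges=edges)
```

Only a few (H-N+1) × (W-N+1) float arrays are alive at once. Windows whose histograms share no bin give a zero numerator and therefore an infinite divergence, with a NumPy divide-by-zero warning from `log10`.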

Upvotes: 6

Views: 360

Answers (2)

meTchaikovsky

Reputation: 7676

By using np.apply_along_axis, you can get rid of the cauchy_schwartz function. Since you are not overly concerned with the pre-processing time, assume you have already obtained an array params containing the flattened patches:

import numpy as np

params = np.random.rand(3, 2, 100)

As you can see, params has shape (3, 2, 100). These numbers are chosen arbitrarily to build an auxiliary array that demonstrates how apply_along_axis works: 3 is the number of patches (determined by the patch shape and the image size), 2 is the number of images, and 100 is the length of a flattened patch. The axes of params are therefore (patch index, image index, entry index within the flattened patch), which matches the list params built by your code:

params = []
for i in range(0,patch1.shape[0],1):
    for j in range(0,patch1.shape[1],1):
        window1 = np.copy(imga[i:i+N,j:j+N]).flatten()
        window2 = np.copy(imgb[i:i+N,j:j+N]).flatten()
        params.append((window1, window2))

With the auxiliary array params, here is my solution:

hist = np.apply_along_axis(lambda x: np.histogram(x, bins=10)[0], 2, params)
# normalize each histogram so that it sums to 1
hist = hist / np.sum(hist, axis=2, keepdims=True)

# axis 1 indexes the two images, so the product over it gives p_k * q_k
n_d = np.sum(np.prod(hist, axis=1), axis=1)
d_d = np.sqrt(np.prod(np.sum(hist**2, axis=2), axis=1))
res = -np.log10(n_d / d_d)
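Note that `np.apply_along_axis` still calls `np.histogram` once per patch in a Python-level loop, so it mainly tidies the code. If both images can share fixed bin edges (an assumption; `np.histogram(x, bins=10)` uses each patch's own min/max), the histograms of all patches can be computed in one call by shifting each patch's bin indices into its own block of `np.bincount`. The helper `batched_histograms` is a hypothetical name for this sketch.

```python
import numpy as np

def batched_histograms(params, edges):
    """Histogram counts for every 1-D row of `params` (shape (..., M)) using shared edges."""
    nbins = len(edges) - 1
    idx = np.digitize(params, edges[1:-1])        # bin index (0..nbins-1) of every entry
    rows = idx.reshape(-1, idx.shape[-1])         # (R, M), one row per patch
    # Give row r the private index range [r*nbins, (r+1)*nbins)
    offsets = nbins * np.arange(rows.shape[0])[:, None]
    counts = np.bincount((rows + offsets).ravel(), minlength=rows.shape[0] * nbins)
    return counts.reshape(params.shape[:-1] + (nbins,))

rng = np.random.default_rng(1)
params = rng.random((3, 2, 100))   # (patches, images, entries), as in the answer
edges = np.linspace(0.0, 1.0, 11)

hist = batched_histograms(params, edges)
hist = hist / np.sum(hist, axis=2, keepdims=True)
n_d = np.sum(np.prod(hist, axis=1), axis=1)
d_d = np.sqrt(np.prod(np.sum(hist**2, axis=2), axis=1))
res = -np.log10(n_d / d_d)
```

When every window is a full NxN patch, `np.asarray(params)` turns the question's list of `(window1, window2)` tuples into this (patches, 2, N*N) layout.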

Upvotes: 3

auraham

Reputation: 1739

First of all, profile your code to identify the bottleneck; you can use https://mg.pov.lt/profilehooks/. I suspect the bottleneck is the creation of the patches: you make a copy of every patch, and each copy then has to be sent to a worker process. You could use less memory by passing only the indices of the patches:

params = []
for i in range(patch1.shape[0]):
    for j in range(patch1.shape[1]):
        # row bounds and column bounds of the patch
        rows, cols = (i, i+N), (j, j+N)
        params.append((rows, cols))

Then, assuming imga and imgb are global, you can create the patches inside the cauchy_schwartz function as shown below:

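# No @njit here: Numba treats global variables as compile-time constants,
# so imga and imgb would be frozen into the compiled function.
def cauchy_schwartz(rows, cols):
    a, b = rows
    c, d = cols
    # ravel() returns a flat array, copying only when the slice is not contiguous
    window1 = imga[a:b, c:d].ravel()
    window2 = imgb[a:b, c:d].ravel()

    # process patches window1 and window2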
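To see why passing indices helps, one can compare how many bytes `pickle` produces for one pair of float64 patches versus one pair of index tuples, since `multiprocessing` pickles every task it sends to a worker. The 256×256 images, `N = 32` and the patch position are arbitrary, hypothetical stand-ins.

```python
import pickle
import numpy as np

rng = np.random.default_rng(0)
imga = rng.random((256, 256))   # hypothetical stand-in images
imgb = rng.random((256, 256))
N = 32
i, j = 10, 20

# What the question sends per task vs. what this answer sends per task
patch_task = (imga[i:i+N, j:j+N].flatten(), imgb[i:i+N, j:j+N].flatten())
index_task = ((i, i + N), (j, j + N))

patch_bytes = len(pickle.dumps(patch_task))
index_bytes = len(pickle.dumps(index_task))
print(patch_bytes, index_bytes)
```

The per-task overhead also depends on how `Pool.imap` batches work; its `chunksize` argument defaults to 1, so each task is sent separately.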

Upvotes: 0
