Athul

Reputation: 13

Speeding up nested for loops in Python

I have a set of nested loops written in Python as follows:

import scipy.stats as sp

for yt in range(dims[1]):
    for xt in range(dims[2]):
        for yp in range(dims[1]):
            for xp in range(dims[2]):
                corr[yt, xt, yp, xp] = sp.spearmanr(prec_tar[:, yt, xt], prec_pre[:, yp, xp], axis=0)[0]
                corr2[yt, xt, yp, xp] = sp.spearmanr(prec_tar[:, yt, xt], prec_pre2[:, yp, xp], axis=0)[0]
                corr3[yt, xt, yp, xp] = sp.spearmanr(prec_tar[:, yt, xt], prec_pre3[:, yp, xp], axis=0)[0]

Here dims is (1710, 69, 21), and corr, corr2, and corr3 are xarray DataArrays wrapping empty NumPy arrays of shape (69, 21, 69, 21).

Now, the issue is that this script takes forever to finish (6+ hours). I'm not sure whether the nested loops or sp.spearmanr is the culprit (or perhaps both). I'm looking for ways to make this run faster; specifically, I'm wondering if it is possible to make use of parallel processing. Other tips are also welcome. Thanks in advance!

Edit: I should also add that prec_tar, prec_pre, prec_pre2, and prec_pre3 all have the same shape as dims (i.e., (1710, 69, 21)).

Upvotes: 0

Views: 147

Answers (3)

Athul

Reputation: 13

This is a working solution to this problem based on @aaron.spring's suggestion. I hope this helps someone someday.

import time
import xarray as xr
from xskillscore import spearman_r as spearmanr

# Problem at hand: very slow.
t1 = time.time()
for i in range(dims[1]):   # dims = (1000, 4, 5)
    for j in range(dims[2]):
        for x in range(dims[1]):
            for y in range(dims[2]):
                acorrb[i, j, x, y] = spearmanr(a[:, i, j], b[:, x, y], dim='time')
t2 = time.time()
print(t2 - t1)  # 0.3600752353668213

# Faster solution based on xarray's vectorized indexing, using xskillscore.spearman_r instead of spearmanr from scipy.stats.

ind_i = xr.DataArray(range(dims[1]), dims=['i'])
ind_j = xr.DataArray(range(dims[2]), dims=['j'])
ind_x = xr.DataArray(range(dims[1]), dims=['x'])
ind_y = xr.DataArray(range(dims[2]), dims=['y'])

t3 = time.time()
acorrb2[ind_i, ind_j, ind_x, ind_y] = spearmanr(a[:, ind_i, ind_j], b[:, ind_x, ind_y], dim='time')
t4 = time.time()
print(t4 - t3)  # 0.07205533981323242

Over 5x faster.

print((acorrb.values == acorrb2.values).all())  # True

Upvotes: 0

aaron.spring

Reputation: 96

You can speed this up by vectorising your code instead of looping.

Try xskillscore, which provides a vectorised and parallelised spearmanr function: https://xskillscore.readthedocs.io/en/stable/api/xskillscore.spearman_r.html#xskillscore.spearman_r
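As an alternative to any library, the whole computation can also be vectorised in plain NumPy: Spearman's rho is just Pearson's r computed on ranks, so ranking along the time axis once and taking one matrix product yields every pair at once. A minimal sketch (not from this answer; the small shapes and array names are illustrative stand-ins for the question's data):

```python
import numpy as np
from scipy.stats import rankdata

rng = np.random.default_rng(0)
T, Y, X = 100, 4, 5                    # stand-ins for (1710, 69, 21)
prec_tar = rng.random((T, Y, X))
prec_pre = rng.random((T, Y, X))

def rank_standardize(a):
    # Rank along the time axis, flatten space, then z-score each column
    # so that a dot product of columns gives Pearson r (= Spearman rho on ranks).
    r = rankdata(a, axis=0).reshape(a.shape[0], -1)   # (T, Y*X)
    return (r - r.mean(axis=0)) / r.std(axis=0)

rt = rank_standardize(prec_tar)
rp = rank_standardize(prec_pre)

# All (Y*X) x (Y*X) correlations in one matrix product, reshaped to (Y, X, Y, X).
corr = (rt.T @ rp / T).reshape(Y, X, Y, X)
```

The same `rank_standardize` output can be reused against prec_pre2 and prec_pre3, so the ranking of prec_tar is done only once instead of inside every loop iteration.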

Upvotes: 0

erncnerky

Reputation: 404

You can use the snippet below to parallelize your code.

import time
import itertools
import multiprocessing

yt = range(2)
xt = range(2)
yp = range(2)
xp = range(2)

param_list = list(itertools.product(yt, xt, yp, xp))

def task(args):
    print(args)
    # the actual per-element work goes here
    time.sleep(1)
    return args

if __name__ == '__main__':  # guard is required so worker processes can import this module safely
    with multiprocessing.Pool() as pool:
        response = pool.map(task, param_list)
    print(response)
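Applied to the question's setup, the worker might look like the sketch below (the small shapes and array names are illustrative stand-ins; on Linux the arrays are inherited by the worker processes via fork):

```python
import itertools
import multiprocessing

import numpy as np
import scipy.stats as sp

rng = np.random.default_rng(0)
T, Y, X = 100, 4, 5                      # stand-ins for (1710, 69, 21)
prec_tar = rng.random((T, Y, X))
prec_pre = rng.random((T, Y, X))

def task(args):
    # Compute one Spearman correlation and return it with its indices,
    # so the parent process can place it in the output array.
    yt, xt, yp, xp = args
    rho = sp.spearmanr(prec_tar[:, yt, xt], prec_pre[:, yp, xp])[0]
    return args, rho

if __name__ == '__main__':
    params = list(itertools.product(range(Y), range(X), range(Y), range(X)))
    corr = np.empty((Y, X, Y, X))
    with multiprocessing.Pool() as pool:
        for (yt, xt, yp, xp), rho in pool.map(task, params):
            corr[yt, xt, yp, xp] = rho
```

Note that each task here is tiny, so process overhead may eat much of the gain; chunking the index space (e.g. one task per (yt, xt) pair) would amortize it better.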

Upvotes: 2
