tacos_life

Reputation: 13

How to remove nested for loop?

I have the following nested loop:

sum_tot = 0.0

for i in range(len(N)-1):
    for j in range(len(N)-1):
        sum_tot = sum_tot + N[i]**2*N[j]**2*W[i]*W[j]*x_i[j][-1] / (N[j]**2 - x0**2) *(z_i[i][j] - z_j[i][j])*x_j[i][-1] / (N[i]**2 - x0**2)

It's basically a mathematical function that has a double summation. Each sum goes up to the length of N. I've been trying to figure out if there was a way to write this without using a nested for-loop in order to reduce computational time. I tried using list comprehension, but the computational time is similar if not the same. Is there a way to write this expression as matrices to avoid the loops?

Upvotes: 1

Views: 2952

Answers (3)

Akshay Sehgal

Reputation: 19322

@Kraigolas makes valid points, but let's benchmark a few approaches on a dummy double-nested operation anyway. (Hint: Numba might help you speed things up.)

Note that I would avoid NumPy arrays here specifically because the entire cross product of the two ranges would be held in memory at once. If the range is very large, you may run out of memory.
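If you do want NumPy for the dummy operation but are worried about memory, one option (a sketch of my own, not part of the benchmarks below) is to broadcast in blocks of rows, so only a block-by-n slab exists at any time instead of the whole n-by-n grid. The function name blocked_sum and the default block size are arbitrary choices:

```python
import numpy as np


def blocked_sum(n, block=1000):
    """Sum (i/2 + j/3) over all i, j in range(n) with NumPy broadcasting,
    processing `block` rows at a time so at most a block-by-n array is
    in memory instead of the full n-by-n grid."""
    j = np.arange(n) / 3                                # column terms, shape (n,)
    total = 0.0
    for start in range(0, n, block):
        i = np.arange(start, min(start + block, n)) / 2  # row terms for this block
        # (rows, 1) + (1, n) broadcasts to a temporary (rows, n) array
        total += (i[:, None] + j[None, :]).sum()
    return total


if __name__ == "__main__":
    n = 5000
    print(blocked_sum(n))
    # For comparison, the full float64 grid would need n * n * 8 bytes:
    print(n * n * 8 / 1e6, "MB")
```

Smaller blocks lower peak memory at the cost of more Python-level iterations.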

Nested for loops

n = 5000
s1 = 0

for i in range(n):
    for j in range(n):
        s1 += (i/2) + (j/3)
        
print(s1)

#2.26 s ± 101 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
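The timings in this answer are formatted like IPython's %timeit output. Outside IPython, a rough equivalent is timeit.repeat from the standard library. This is a sketch using a smaller n so it finishes quickly; the name nested_loops is just for illustration, and the spread is computed with statistics.stdev, which may differ slightly from how %timeit reports it:

```python
import statistics
import timeit


def nested_loops(n):
    """The dummy double loop from the benchmark above."""
    s = 0
    for i in range(n):
        for j in range(n):
            s += (i / 2) + (j / 3)
    return s


if __name__ == "__main__":
    n = 1000  # smaller than 5000 so the sketch runs quickly
    # 7 runs of 1 call each, mirroring "7 runs, 1 loop each"
    runs = timeit.repeat(lambda: nested_loops(n), repeat=7, number=1)
    print(f"{statistics.mean(runs):.3f} s ± {statistics.stdev(runs):.3f} s per loop")
```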

List comprehension

n = 5000
s2 = 0

s2 = sum([i/2+j/3 for i in range(n) for j in range(n)])
print(s2)

#3.2 s ± 307 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
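Note that the list comprehension builds the full list of n*n terms (25 million for n = 5000) before sum() sees any of them. Passing a generator expression to sum() avoids that list. It is not necessarily faster, but it keeps memory flat. A sketch (the function name comprehension_sum is illustrative):

```python
def comprehension_sum(n):
    # Generator expression: sum() consumes one term at a time instead of
    # first materializing a list with n * n floats.
    return sum(i / 2 + j / 3 for i in range(n) for j in range(n))
```

Call it as comprehension_sum(5000) to match the benchmark above.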

Itertools product

from itertools import product

n = 5000
s3 = 0

for i,j in product(range(n),repeat=2):
    s3 += (i/2) + (j/3)
print(s3)

#2.35 s ± 186 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Note: when using Numba, run the function at least once before timing it, because the first call compiles the code and is therefore slow. The real speedup shows from the second call onwards.

Numba njit (SIMD)

from numba import njit

n=5000

@njit
def f(n):
    s = 0
    for i in range(n):
        for j in range(n):
            s += (i/2) + (j/3)
    return s

s4 = f(n)

#29.4 ms ± 1.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Numba njit parallel with prange

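An excellent suggestion by @Tim, added to the benchmarks:

from numba import njit, prange

@njit(parallel=True)
def f(n):
    s = 0.0
    # Only the outermost prange runs in parallel; Numba treats a nested
    # prange as a plain range, so the inner loop is written as range.
    for i in prange(n):
        for j in range(n):
            s += (i/2) + (j/3)  # Numba recognizes s as a reduction variable
    return s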

s5 = f(n)

#21.8 ms ± 4.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Numba gives a significant speedup, as expected. Maybe try it?

Upvotes: 2

Tim

Reputation: 3417

To convert this to matrix calculations, I would suggest combining some terms first.
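Written out, the combination works as follows. The shorthand here is not from the original answer: p_k = x_i[k][-1], q_k = x_j[k][-1], M_ij = z_i[i][j] - z_j[i][j], and n = len(N), summing over the full index range as the vectorized code does:

```latex
\begin{aligned}
c_k &= \frac{N_k^2 W_k}{N_k^2 - x_0^2}, \qquad a_i = c_i\, q_i, \qquad b_j = c_j\, p_j,\\
S &= \sum_{i=0}^{n-1}\sum_{j=0}^{n-1}
     \frac{N_i^2 N_j^2 W_i W_j\, p_j\, q_i}{(N_j^2 - x_0^2)(N_i^2 - x_0^2)}\, M_{ij}\\
  &= \sum_{i=0}^{n-1}\sum_{j=0}^{n-1} a_i\, M_{ij}\, b_j \;=\; a^{\mathsf{T}} M\, b.
\end{aligned}
```

So c corresponds to common_terms, a to i_terms, b to j_terms, and M to i_j_matrix.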

If these objects are not numpy arrays, it's better to convert them to numpy arrays, as they support element-wise operations.

To convert, simply do

import numpy
N = numpy.array(N)
w = numpy.array(W)  # the question names this array W
x_i = numpy.array(x_i)
x_j = numpy.array(x_j)
z_i = numpy.array(z_i)
z_j = numpy.array(z_j)

Then,


common_terms = N**2*w/(N**2-x0**2)
i_terms = common_terms*x_j[:,-1]
j_terms = common_terms*x_i[:,-1]
i_j_matrix = z_i - z_j
sum_output = (i_terms.reshape((1,-1)) @ i_j_matrix @ j_terms.reshape((-1,1)))[0,0]
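One way to check that the matrix form matches the loops is to compare both on random data. This sketch assumes, since the question doesn't say, that x_i and x_j are n-by-m arrays and z_i and z_j are n-by-n. It loops over the full range(len(N)) because the vectorized version covers every index. The helper names loop_sum and matrix_sum and the test sizes are hypothetical:

```python
import numpy as np


def loop_sum(N, W, x_i, x_j, z_i, z_j, x0):
    """The question's double loop, but over the full range(len(N))."""
    total = 0.0
    for i in range(len(N)):
        for j in range(len(N)):
            total += (N[i]**2 * N[j]**2 * W[i] * W[j] * x_i[j][-1] / (N[j]**2 - x0**2)
                      * (z_i[i][j] - z_j[i][j]) * x_j[i][-1] / (N[i]**2 - x0**2))
    return total


def matrix_sum(N, W, x_i, x_j, z_i, z_j, x0):
    """The vectorized form a^T (z_i - z_j) b from this answer."""
    N, W = np.asarray(N), np.asarray(W)
    x_i, x_j = np.asarray(x_i), np.asarray(x_j)
    common_terms = N**2 * W / (N**2 - x0**2)
    i_terms = common_terms * x_j[:, -1]
    j_terms = common_terms * x_i[:, -1]
    # 1-D @ 2-D @ 1-D already yields a scalar, so no reshape is needed
    return float(i_terms @ (np.asarray(z_i) - np.asarray(z_j)) @ j_terms)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, m = 200, 4                       # hypothetical sizes for the test data
    N = rng.uniform(2.0, 3.0, n)        # kept away from x0 to avoid dividing by ~0
    W = rng.uniform(size=n)
    x_i, x_j = rng.uniform(size=(n, m)), rng.uniform(size=(n, m))
    z_i, z_j = rng.uniform(size=(n, n)), rng.uniform(size=(n, n))
    x0 = 1.0
    print(loop_sum(N, W, x_i, x_j, z_i, z_j, x0))
    print(matrix_sum(N, W, x_i, x_j, z_i, z_j, x0))
```

Both calls should print the same value up to floating-point rounding.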

Upvotes: 0

Kraigolas

Reputation: 5570

Note that, given your current loop, i and j only go up to len(N)-2: range goes up to but not including its argument. You probably mean to write for i in range(len(N)).
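A quick way to see which indices each version visits (the contents of N here are made up; only its length matters):

```python
N = [10, 20, 30, 40]  # hypothetical values; only len(N) matters

print(list(range(len(N) - 1)))  # [0, 1, 2]    -> the question's loop skips index 3
print(list(range(len(N))))      # [0, 1, 2, 3] -> every index of N
```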

It's also difficult to reduce the cost of a summation: the time it takes depends on the number of terms computed, so if you rewrite it in a different way that still involves the same number of terms, it will still scale the same way. However, O(n^2) isn't exactly bad: it looks like the best you can do in this situation unless you find a mathematical simplification of the problem.
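As an illustration of the kind of simplification meant here, take the dummy benchmark from the other answer rather than the question's sum: the summand i/2 + j/3 separates, and each index runs over 0..n-1 (which sums to n(n-1)/2) once for every value of the other index, so the double sum collapses to n * n(n-1)/2 * (1/2 + 1/3) with no loop at all. The helper names are illustrative:

```python
def nested_sum(n):
    # O(n^2): the dummy double-nested operation, term by term
    return sum(i / 2 + j / 3 for i in range(n) for j in range(n))


def closed_form_sum(n):
    # O(1): 0 + 1 + ... + (n - 1) = n(n - 1)/2, and each index sum
    # is repeated once for every value of the other index (n times)
    index_sum = n * (n - 1) / 2
    return n * index_sum / 2 + n * index_sum / 3


if __name__ == "__main__":
    n = 5000
    print(closed_form_sum(n))
```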

You might consider checking this post to gather ways to write out the summation in a neater fashion.

Upvotes: 2
