Jack Fu

Reputation: 43

optimize python counting in a three-iteration statement

all,

I am trying to count the number of items that meet all of the following conditions: house id == m, transformed day segment id == n, and neighborhood function == k. House ids are given by docs['house_id'], day segment ids by docs['transformed_dayseg_id'], and neighborhood functions by self.CF/self.TF/self.BF.

To do this, I count with the following code. However, it is too slow. Any ideas for optimizing this Python counting code?

def get_mnf_counter(self, docs, dtype):
    x = np.zeros((self.M,self.N,self.K))
    for m in range(self.M):
        for n in range(self.N):
            for k in range(self.K):
                if dtype==1: #checkin
                    x[m, n, k]=sum((np.array(docs['house_id'])==m) & (np.array(docs['transformed_dayseg_id'])==n) & (self.CF==k))
                elif dtype==2: #taxi
                    x[m, n, k]=sum((np.array(docs['house_id'])==m) & (np.array(docs['transformed_dayseg_id'])==n) & (self.TF==k))
                elif dtype==3: #bus
                    x[m, n, k]=sum((np.array(docs['house_id'])==m) & (np.array(docs['transformed_dayseg_id'])==n) & (self.BF==k))
                else:
                    raise Exception("index of checkin/taxi/bus/ is wrong")
    return x

Upvotes: 1

Views: 59

Answers (1)

thefourtheye

Reputation: 239653

You can use itertools.product, like this

from itertools import product
...
...

for m, n, k in product(range(self.M), range(self.N), range(self.K)):
   ...
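As a quick check, the sketch below confirms that product visits the same index triples, in the same order, as three nested loops. The sizes M, N and K here are small hypothetical values chosen only for illustration.

```python
from itertools import product

M, N, K = 2, 3, 2  # hypothetical sizes, for illustration only

# Index triples produced by three nested loops
nested = [(m, n, k) for m in range(M) for n in range(N) for k in range(K)]

# Index triples produced by a single product() call
flat = list(product(range(M), range(N), range(K)))

print(nested == flat)  # True
print(len(flat))       # 12
```

Note that this only flattens the loop structure; the number of iterations stays M * N * K.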

Better still, you can optimize your code further by hoisting the dtype check and the array conversions out of the loop, like this

import numpy as np
from itertools import product
def get_mnf_counter(self, docs, dtype):
    x = np.zeros((self.M, self.N, self.K))

    if dtype not in (1, 2, 3):
        raise Exception("index of checkin/taxi/bus/ is wrong")

    if dtype == 1:
        value = self.CF
    elif dtype == 2:
        value = self.TF
    else:
        value = self.BF

    house_id = np.array(docs['house_id'])
    dayseg_id = np.array(docs['transformed_dayseg_id'])

    for m, n, k in product(range(self.M), range(self.N), range(self.K)):
        x[m, n, k] = sum((house_id == m) & (dayseg_id == n) & (value == k))

    return x

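The loop still visits all M * N * K combinations, and each iteration scans every record, so this restructuring mainly saves the repeated array conversions and the per-iteration dtype check.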
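One further option, offered as a sketch rather than a drop-in replacement: if the house ids, day segment ids and function values are aligned element by element (an assumption, though the element-wise & in the code above suggests it), the counts can be gathered in a single pass over the records instead of one scan per (m, n, k) cell. The function name count_triples and the sample data below are hypothetical, and the result is a nested list rather than a NumPy array.

```python
from collections import Counter

def count_triples(house_ids, dayseg_ids, functions, M, N, K):
    """Count how often each (house, day segment, function) triple occurs."""
    # One pass over the records: tally every triple that actually appears.
    counts = Counter(zip(house_ids, dayseg_ids, functions))

    # M x N x K nested list of zeros, filled in from the tallies.
    x = [[[0] * K for _ in range(N)] for _ in range(M)]
    for (m, n, k), c in counts.items():
        # Skip any triple outside the expected index ranges.
        if 0 <= m < M and 0 <= n < N and 0 <= k < K:
            x[m][n][k] = c
    return x

# Hypothetical sample data: five records.
house_ids = [0, 0, 1, 1, 1]
dayseg_ids = [0, 0, 2, 2, 1]
functions = [1, 1, 0, 0, 1]

x = count_triples(house_ids, dayseg_ids, functions, M=2, N=3, K=2)
print(x[0][0][1])  # 2
print(x[1][2][0])  # 2
print(x[1][1][1])  # 1
```

With NumPy integer arrays, np.add.at(x, (house_id, dayseg_id, value), 1) is one way to perform the same accumulation directly into the zero-initialized array x, again assuming the three arrays have equal length and hold valid indices.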

Upvotes: 2
