Elkan

Reputation: 616

Fastest way to count elements in a list that satisfy a condition

I want to count the elements of one list that satisfy a condition defined by another list. My current approach uses sum and any. A simple test:

>>> import random
>>> x1 = list(xrange(300))
>>> x2 = [random.randrange(20, 50) for i in xrange(30)]
>>> def test():
        ns = []
        for i in xrange(10000):
            ns.append(sum(1 for j in x2 if any(abs(k-j)<=10 for k in x1)))
        return ns

The profiler shows that sum and any take most of the time. Is there any way to improve this?

>>> import cProfile
>>> cProfile.run('ns = test()')
     8120003 function calls in 0.699 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.003    0.003    1.552    1.552 <pyshell#678>:2(test)
   310000    0.139    0.000    1.532    0.000 <pyshell#678>:5(<genexpr>)
        1    0.000    0.000    1.552    1.552 <string>:1(<module>)
  7490000    0.196    0.000    0.196    0.000 {abs}
   300000    0.345    0.000    1.377    0.000 {any}
    10000    0.001    0.000    0.001    0.000 {method 'append' of 'list' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
    10000    0.016    0.000    1.548    0.000 {sum}

The function test runs only 10000 iterations. In practice I would run tens of thousands of iterations, and cProfile.run shows that this block accounts for the majority of the execution time.

===================================================================

Edit

Following @DavisHerring's answer, here is a version using binary search:

>>> from bisect import bisect_left
>>> x1 = list(xrange(300))
>>> x2 = [random.randrange(20, 50) for i in xrange(30)]
>>> def testx():
        ns = []
        x2k = sorted(x2)
        x1k = sorted(x1)
        for i in xrange(10000):
            # k is the index of the first element of x1k that is >= xk
            bx = [bisect_left(x1k, xk) for xk in x2k]
            # only the neighbours x1k[k-1] and x1k[k] can be within 10 of xk;
            # guard both indexes so k == 0 and k == len(x1k) are handled
            rn = sum(1 for k, xk in zip(bx, x2k)
                     if (k > 0 and xk - x1k[k-1] <= 10)
                     or (k < len(x1k) and x1k[k] - xk <= 10))
            ns.append(rn)
        return ns

cProfile.run reports 0.196 seconds for this version, more than 3x faster.

Upvotes: 4

Views: 187

Answers (2)

Bolo

Reputation: 11690

The code

Use an interval tree data structure. A very simple implementation suitable for your needs could be as follows:

class SimpleIntervalTree:
    def __init__(self, points, radius):
        intervals = []
        l, r = None, None
        for p in sorted(points):
            if r is None or r < p - radius:
                if r is not None:
                    intervals.append((l, r))
                l = p - radius
            r = p + radius
        if r is not None:
            intervals.append((l, r))
        self._tree = self._to_tree(intervals)

    def _to_tree(self, intervals):
        if len(intervals) == 0:
            return None
        i = len(intervals) // 2
        return {
            'left': self._to_tree(intervals[0:i]),
            'value': intervals[i],
            'right': self._to_tree(intervals[i + 1:])
        }

    def __contains__(self, item):
        t = self._tree
        while t is not None:
            l, r = t['value']
            if item < l:
                t = t['left']
            elif item > r:
                t = t['right']
            else:
                return True
        return False

Then your code would look like this:

import random

x1 = list(range(300))
x2 = [random.randrange(20, 50) for i in range(30)]
it = SimpleIntervalTree(x1, 10)
def test():
    ns = []
    for i in range(10000):
        ns.append(sum(1 for j in x2 if j in it))
    return ns

What the code does

In __init__, each point p is turned into the interval [p - radius, p + radius], and overlapping intervals are merged so that the result is a sorted list of disjoint intervals. Next, the intervals are put in a balanced binary search tree. In that tree, each node contains an interval, the left subtree of a node contains the lower intervals and the right subtree contains the higher intervals. This way, whenever we want to test whether a point is in any of the intervals (__contains__), we perform a binary search starting from the root.
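As a rough illustration of the same merge-then-search idea, here is a minimal sketch that swaps the tree for two flat sorted lists and the standard bisect module. It is not the class above: the helper names merge_intervals and covered are made up for this example, and the sample points are arbitrary.

```python
from bisect import bisect_right


def merge_intervals(points, radius):
    """Merge [p - radius, p + radius] for every point into disjoint intervals.

    Hypothetical helper; returns two parallel lists (starts, ends).
    """
    starts, ends = [], []
    for p in sorted(points):
        if ends and p - radius <= ends[-1]:
            # Overlaps or touches the previous interval: extend it.
            ends[-1] = p + radius
        else:
            # Gap before this point: open a new interval.
            starts.append(p - radius)
            ends.append(p + radius)
    return starts, ends


def covered(item, starts, ends):
    """Return True if item lies inside one of the merged intervals."""
    # Index of the last interval whose start is <= item, or -1 if none.
    i = bisect_right(starts, item) - 1
    return i >= 0 and item <= ends[i]


starts, ends = merge_intervals([0, 5, 40], 10)
# starts == [-10, 30] and ends == [15, 50]:
# the intervals around 0 and 5 overlap and are merged into one.
print(covered(12, starts, ends))  # True, 12 is inside [-10, 15]
print(covered(20, starts, ends))  # False, 20 falls in the gap
```

With x1 = range(300) and a radius of 10, every point overlaps its neighbour, so the merge collapses to the single interval [-10, 309] and each lookup is a single comparison.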

Upvotes: 0

Davis Herring

Reputation: 39778

The nature of your predicate is critical; because it is a distance along a line, you can give your data a corresponding structure to speed the search. There are several variations:

Sort the list x1: then you can use binary search to find the nearest values and check whether they’re close enough.

If the list x2 is much longer, and most of its elements are not in range, you can make it a bit faster by sorting it instead and searching for the beginning and end of each acceptable interval.

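If you sort both lists, you can step through them together and do the matching in linear time. Since sorting dominates the cost, this is asymptotically equivalent to the binary-search approach unless the lists are already sorted or there’s another reason to sort them, of course.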
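The third variation (sorting both lists and stepping through them together) might look roughly like the sketch below. The function name count_near_sweep and the sample data are made up for illustration, and the radius of 10 is taken from the question.

```python
def count_near_sweep(x1, x2, radius):
    """Count elements of x2 that lie within radius of some element of x1."""
    xs, ys = sorted(x1), sorted(x2)
    count = 0
    i = 0
    for v in ys:
        # Skip elements of xs that are too far below v. Because ys is
        # ascending, they are too far below every later v as well, so
        # the index i never has to move backwards.
        while i < len(xs) and xs[i] < v - radius:
            i += 1
        # xs[i] is now the smallest element >= v - radius, so v has a
        # neighbour within radius exactly when xs[i] <= v + radius.
        if i < len(xs) and xs[i] <= v + radius:
            count += 1
    return count


x1 = [0, 100, 200]
x2 = [5, 50, 95, 111, 250]
print(count_near_sweep(x1, x2, 10))  # prints 2 (5 is near 0, 95 is near 100)
```

After the two sorts, the loop advances each index at most once per element, so the matching itself is linear in len(x1) + len(x2).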

Upvotes: 2
