TrailDreaming

Reputation: 567

Why is the "map" version of ThreeSum so slow?

I expected this Python implementation of ThreeSum to be slow:

def count(a):
    """ThreeSum: Given N distinct integers, how many triples sum to exactly zero?"""
    N = len(a)
    cnt = 0
    for i in range(N):
        for j in range(i+1, N):
            for k in range(j+1, N):
                if sum([a[i], a[j], a[k]]) == 0:
                    cnt += 1
    return cnt

But I was shocked that this version looks pretty slow too:

import itertools

def count_python(a):
    """ThreeSum using itertools"""
    return sum(map(lambda x: sum(x) == 0, itertools.combinations(a, r=3)))

Can anyone recommend a faster Python implementation? Both implementations just seem so slow... Thanks

...

ANSWER SUMMARY: Here is how all the O(N^3) versions provided in this thread (kept for educational purposes; you wouldn't use O(N^3) in real life) ran on my machine:

56 sec RUNNING count_slow...
28 sec RUNNING count_itertools, written by Ashwini Chaudhary...
14 sec RUNNING count_fixed, written by roippi...
11 sec RUNNING count_itertools (faster), written by Veedrak...
08 sec RUNNING count_enumerate, written by roippi...

*Note: I needed to modify Veedrak's solution as follows to get the correct count output:
sum(1 for x, y, z in itertools.combinations(a, r=3) if x+y==-z)

Upvotes: 3

Views: 245

Answers (3)

roippi

Reputation: 25974

Supplying a second answer. From various comments, it looks like you're primarily concerned about why this particular O(n**3) algorithm is slow when ported over from Java. Let's dive in.

def count(a):
    """ThreeSum: Given N distinct integers, how many triples sum to exactly zero?"""
    N = len(a)
    cnt = 0
    for i in range(N):
        for j in range(i+1, N):
            for k in range(j+1, N):
                if sum([a[i], a[j], a[k]]) == 0:
                    cnt += 1
    return cnt

One major problem that immediately pops out is that you're doing something your java code almost certainly isn't doing: materializing a 3-element list just to add three numbers together!

if sum([a[i], a[j], a[k]]) == 0:

Yuck! Just write that as

if a[i] + a[j] + a[k] == 0:

Some benchmarking shows that you're adding 50%+ overhead just by doing that. Yikes.
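You can see that overhead for yourself with a quick micro-benchmark (a sketch; the absolute numbers depend entirely on your machine):

```python
import timeit

# adding three ints directly vs. materializing a temporary list and calling sum()
direct = timeit.timeit('a + b + c', setup='a, b, c = 1, 2, -3', number=500_000)
listed = timeit.timeit('sum([a, b, c])', setup='a, b, c = 1, 2, -3', number=500_000)
print(f'direct add: {direct:.3f}s   sum(list): {listed:.3f}s')
```

On CPython the `sum([...])` form pays for a list allocation plus a function call on every single iteration of the triple loop.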


The other issue here is that you're using indexing where you should be using iteration. In python try to avoid writing code like this:

for i in range(len(some_list)):
    do_something(some_list[i])

And instead just write:

for x in some_list:
    do_something(x)

And if you explicitly need the index that you're on (as you actually do in your code), use enumerate:

for i,x in enumerate(some_list):
    #etc

This is, in general, a style thing (though it goes deeper than that, with duck typing and the iterator protocol) - but it is also a performance thing. In order to look up the value of a[i], that call is converted to a.__getitem__(i), then python has to dynamically resolve a __getitem__ method lookup, call it, and return the value. Every time. It's not a crazy amount of overhead - at least on builtin types - but it adds up if you're doing it a lot in a loop. Treating a as an iterable, on the other hand, sidesteps a lot of that overhead.
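To make that concrete, here's a small (machine-dependent) sketch comparing the two loop styles over the same list:

```python
import timeit

xs = list(range(1000))

# index-based loop: range(), len(), and a __getitem__ call per element
by_index = timeit.timeit('for i in range(len(xs)): xs[i]',
                         globals={'xs': xs}, number=5_000)
# direct iteration: the list's iterator hands back each element
by_iter = timeit.timeit('for x in xs: x',
                        globals={'xs': xs}, number=5_000)
print(f'indexing: {by_index:.3f}s   iterating: {by_iter:.3f}s')

# xs[i] is sugar for the __getitem__ protocol described above
assert xs[5] == xs.__getitem__(5)
```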

So taking that change in mind, you can rewrite your function once again:

def count_enumerate(a):
    cnt = 0
    for i, x in enumerate(a):
        for j, y in enumerate(a[i+1:], i+1):
            for z in a[j+1:]:
                if x + y + z == 0:
                    cnt += 1
    return cnt

Let's look at some timings:

%timeit count(range(-100,100))
1 loops, best of 3: 394 ms per loop

%timeit count_fixed(range(-100,100)) #just fixing your sum() line
10 loops, best of 3: 158 ms per loop

%timeit count_enumerate(range(-100,100))
10 loops, best of 3: 88.9 ms per loop

And that's about as fast as it's going to go. You can shave off a percent or so by wrapping everything in a comprehension instead of doing cnt += 1 but that's pretty minor.
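For reference, that comprehension-style variant might look like this (my sketch of the idea, not code from this thread):

```python
def count_comprehension(a):
    # same triple loop as count_enumerate, folded into one generator
    # expression; summing explicit 1s replaces the cnt += 1 bookkeeping
    return sum(1
               for i, x in enumerate(a)
               for j, y in enumerate(a[i+1:], i+1)
               for z in a[j+1:]
               if x + y + z == 0)
```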

I've toyed around with a few itertools implementations but I actually can't get them to go faster than this explicit loop version. This makes sense if you think about it - for every iteration, the itertools.combinations version has to rebind what all three variables refer to, whereas the explicit loops get to "cheat" and rebind the variables in the outer loops far less often.

Reality check time, though: after everything is said and done, you can still expect cPython to run this algorithm an order of magnitude slower than a modern JVM would. There is simply too much abstraction built in to python that gets in the way of looping quickly. If you care about speed (and you can't fix your algorithm - see my other answer), either use something like numpy to spend all of your time looping in C, or use a different implementation of python.
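As a sketch of the numpy route (the function and this particular vectorization are my own illustration, not from this thread): keep one Python loop and push the inner pair check down into C with broadcasting.

```python
import numpy as np

def count_numpy(a):
    arr = np.asarray(a)
    cnt = 0
    for i in range(len(arr) - 2):
        rest = arr[i+1:]
        # pair[p, q] == rest[p] + rest[q], computed in one C-level broadcast
        pair = rest[:, None] + rest[None, :]
        # strictly-upper-triangular entries (p < q) equal to -arr[i]
        # complete a zero triple
        cnt += int(np.count_nonzero(np.triu(pair == -arr[i], k=1)))
    return cnt
```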


postscript: pypy

For fun, I ran count_fixed on a 1000-element list, on both cPython and pypy.

cPython:

In [81]: timeit.timeit('count_fixed(range(-500,500))', setup='from __main__ import count_fixed', number = 1)
Out[81]: 19.230753898620605

pypy:

>>>> timeit.timeit('count_fixed(range(-500,500))', setup='from __main__ import count_fixed', number = 1)
0.6961538791656494

Speedy!

I might add some java testing in later to compare :-)

Upvotes: 3

roippi

Reputation: 25974

Algorithmically, both versions of your function are O(n**3) - so asymptotically neither is superior. You will find that the itertools version is in practice somewhat faster since it spends more time looping in C rather than in python bytecode. You can get it down a few more percentage points by removing map entirely (especially if you're running py2) but it's still going to be "slow" compared to whatever times you got from running it in a JVM.
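Removing map in favor of a generator expression might look like this (a sketch; booleans sum as 0/1, same trick as the original):

```python
import itertools

def count_nomap(a):
    # generator expression instead of map + lambda: no lambda call per triple
    return sum(x + y + z == 0 for x, y, z in itertools.combinations(a, 3))
```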

Note that there are plenty of python implementations other than cPython out there - for loopy code, pypy tends to be much faster than cPython. So I wouldn't write python-as-a-language off as being slow, necessarily, but I would certainly say that the reference implementation of python is not known for its blazing loop speed. Give other python flavors a shot if that's something you care about.

Specific to your algorithm, an optimization will let you drop it down to O(n**2). Build up a set of your integers, s, and build up all pairs (a,b). You know that you can "zero out" (a+b) if and only if -(a+b) in (s - {a,b}).

Thanks to @Veedrak: unfortunately constructing s - {a,b} is a slow O(len(s)) operation itself - so simply check if -(a+b) is equal to either a or b. If it is, you know there's no third c that can fulfill a+b+c == 0 since all numbers in your input are distinct.

def count_python_faster(a):
    s = frozenset(a)
    return sum(1 for x, y in itertools.combinations(a, 2)
               if -(x+y) not in (x, y) and -(x+y) in s) // 3

Note the divide-by-three at the end; this is because each successful combination is triple-counted. It's possible to avoid that but it doesn't actually speed things up and (imo) just complicates the code.
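For the curious, one way to avoid the triple count (an illustrative sketch of the complication alluded to above; `count_single_pass` is my own name) is to count a pair only when the required third value occurs strictly later in the array:

```python
import itertools

def count_single_pass(a):
    # position of each value; the input integers are assumed distinct
    pos = {v: i for i, v in enumerate(a)}
    # a zero triple at positions p < q < r is counted only from the pair
    # (p, q), because only then does the third value sit after the second
    return sum(1
               for (_, x), (j, y) in itertools.combinations(enumerate(a), 2)
               if pos.get(-(x + y), -1) > j)
```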

Some timings for the curious:

%timeit count(range(-100,100))
1 loops, best of 3: 407 ms per loop

%timeit count_python(range(-100,100)) #this is about 100ms faster on py3
1 loops, best of 3: 382 ms per loop

%timeit count_python_faster(range(-100,100))
100 loops, best of 3: 5.37 ms per loop

Upvotes: 3

isedev

Reputation: 19641

You haven't stated which version of Python you're using.

In Python 3.x, a generator expression is around 10% faster than either of the two implementations you listed. Using a random array of 100 numbers in the range [-100,100] for a:

count(a)           -> 8.94 ms  # as per your implementation
count_python(a)    -> 8.75 ms  # as per your implementation

def count_generator(a):
    return sum(sum(x) == 0 for x in itertools.combinations(a, r=3))

count_generator(a) -> 7.63 ms

But other than that, it's the sheer number of combinations that dominates execution time - O(N^3).

I should add the times shown above are for loops of 10 calls each, averaged over 10 loops. And yeah, my laptop is slow too :)

Upvotes: 2
