Splash

Reputation: 127

Python list comprehension performance

I did some experiments with two ways to initialize a two-dimensional list.

I tested on both my local MacBook Pro and the LeetCode playground; the results show that the first method is 4-5 times faster than the second.

Can anyone explain why the list comprehension version is slower?

import time

n = 999
t0 = time.time()
arr1 = [[None] * n for _ in range(n)]
t1 = time.time()
print(t1 - t0)

t2 = time.time()
arr2 = [[None for _ in range(n)] for _ in range(n)]
t3 = time.time()
print(t3 - t2)
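A single time.time() measurement is easily skewed by noise. A somewhat more robust way to compare the same two statements is the standard-library timeit module, which runs each statement many times and lets you keep the best trial. This is only a sketch: compare_init and its parameters are illustrative names, not something from the original post.

```python
import timeit


def compare_init(n: int = 999, repeat: int = 5, number: int = 10) -> dict:
    """Return the best per-call time (in seconds) for each initialization style.

    Hypothetical helper for illustration; the statements are the two from the question.
    """
    namespace = {"n": n}  # makes `n` visible to the timed statements
    stmts = {
        "repetition": "[[None] * n for _ in range(n)]",
        "nested_comprehension": "[[None for _ in range(n)] for _ in range(n)]",
    }
    results = {}
    for name, stmt in stmts.items():
        # Each trial runs `stmt` `number` times; the fastest trial is the least noisy.
        times = timeit.repeat(stmt, repeat=repeat, number=number, globals=namespace)
        results[name] = min(times) / number
    return results


if __name__ == "__main__":
    for name, secs in compare_init().items():
        print(f"{name}: {secs * 1e3:.2f} ms per call")
```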

Upvotes: 3

Views: 931

Answers (3)

Qinsheng Zhang

Reputation: 1323

Just a toy experiment: it can be sped up further by creating a NumPy array instead, if the element datatype is compatible with NumPy (note that np.zeros fills the array with 0.0 rather than None).

%timeit [[None] * n for _ in range(n)]
1.42 ms ± 6.84 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit [[None for _ in range(n)] for _ in range(n)]
17.3 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit np.zeros((n,n))
148 µs ± 440 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Upvotes: 1

juanpa.arrivillaga

Reputation: 95948

Note that you are doing two different things. You meant to use:

[[None] * n for _ in range(n)]

You've wrapped your inner lists in an additional list, but that won't make a huge difference in the timing results. The list repetition version is definitely faster.

[None]*n is very fast: it allocates the underlying buffer at exactly the size it needs, then fills it in a C-level loop. [None for _ in range(n)] is a Python-level loop that appends one element at a time, which is amortized constant time per append but involves buffer re-allocations as the list grows.
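In CPython, sys.getsizeof on a list reports its allocated capacity, not just its length, so both the re-allocations and the exact sizing can be observed directly. A minimal sketch, assuming CPython; capacity_growth is a hypothetical helper, and the exact byte counts vary with the CPython version and platform.

```python
import sys


def capacity_growth(n: int) -> list:
    """Record sys.getsizeof after each append; every change marks a re-allocation.

    Hypothetical helper; relies on CPython's list.__sizeof__ reporting capacity.
    """
    lst = []
    sizes = [sys.getsizeof(lst)]
    for _ in range(n):
        lst.append(None)
        size = sys.getsizeof(lst)
        if size != sizes[-1]:
            sizes.append(size)
    return sizes


n = 999
repeated = [None] * n                    # one allocation of exactly n slots
comprehended = [None for _ in range(n)]  # grown by appends, may be over-allocated

print("re-allocations while appending:", len(capacity_growth(n)) - 1)
print("getsizeof([None] * n):         ", sys.getsizeof(repeated))
print("getsizeof(comprehension):      ", sys.getsizeof(comprehended))
```

The comprehension's list is never smaller than the repeated one, because over-allocation leaves spare slots at the end.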

Just looking at the bytecode gives a hint:

>>> import dis
>>> dis.dis('[None]*n')
  1           0 LOAD_CONST               0 (None)
              2 BUILD_LIST               1
              4 LOAD_NAME                0 (n)
              6 BINARY_MULTIPLY
              8 RETURN_VALUE

Basically, all the work is done in BINARY_MULTIPLY. For the list comprehension:

>>> dis.dis("[None for _ in range(n)]")
  1           0 LOAD_CONST               0 (<code object <listcomp> at 0x7fc06e31bea0, file "<dis>", line 1>)
              2 LOAD_CONST               1 ('<listcomp>')
              4 MAKE_FUNCTION            0
              6 LOAD_NAME                0 (range)
              8 LOAD_NAME                1 (n)
             10 CALL_FUNCTION            1
             12 GET_ITER
             14 CALL_FUNCTION            1
             16 RETURN_VALUE

Disassembly of <code object <listcomp> at 0x7fc06e31bea0, file "<dis>", line 1>:
  1           0 BUILD_LIST               0
              2 LOAD_FAST                0 (.0)
        >>    4 FOR_ITER                 8 (to 14)
              6 STORE_FAST               1 (_)
              8 LOAD_CONST               0 (None)
             10 LIST_APPEND              2
             12 JUMP_ABSOLUTE            4
        >>   14 RETURN_VALUE
>>>

The looping work is done at the Python interpreter level. Also, it grows the list by appending (the LIST_APPEND opcode above), which is algorithmically efficient but still slower than list repetition, where all of the work is pushed down into the C layer.
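To put the comprehension between the two extremes, a sketch can time it against list repetition and against an explicit for loop calling .append. The comprehension's LIST_APPEND opcode appends directly, while the explicit loop also pays for an attribute lookup and a method call on every iteration. The function names here are illustrative, and the relative timings will depend on your Python version and machine.

```python
import timeit

N = 999


def by_repetition(n: int = N) -> list:
    return [None] * n  # single C-level fill of an exactly sized buffer


def by_comprehension(n: int = N) -> list:
    return [None for _ in range(n)]  # interpreter loop using the LIST_APPEND opcode


def by_append_loop(n: int = N) -> list:
    out = []
    for _ in range(n):
        out.append(None)  # interpreter loop plus attribute lookup and method call
    return out


if __name__ == "__main__":
    for fn in (by_repetition, by_comprehension, by_append_loop):
        best = min(timeit.repeat(fn, repeat=5, number=1000))
        print(f"{fn.__name__:>16}: {best:.4f} s per 1000 calls")
```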

Here is the C source code:

https://github.com/python/cpython/blob/48ed88a93bb0bbeaae9a4cfaa533e4edf13bcb51/Objects/listobject.c#L504

As you can see, it allocates the underlying buffer to the exact size it needs:

np = (PyListObject *) PyList_New(size);

Then, it does a quick loop, filling up the buffer without re-allocations. The most general case:

p = np->ob_item;
items = a->ob_item;
for (i = 0; i < n; i++) {
    for (j = 0; j < Py_SIZE(a); j++) {
        *p = items[j];
        Py_INCREF(*p);
        p++;
    }
}
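The C loop above can be modelled in pure Python for illustration only; the real speed comes from running it in C. The model makes one point visible: each slot receives a reference to the same object (the C code just increments its reference count), nothing is copied. list_repeat_model is a hypothetical name, and [None] * size merely stands in for PyList_New(size).

```python
def list_repeat_model(a: list, n: int) -> list:
    """Pure-Python illustration of what `a * n` does in C (hypothetical helper)."""
    if n <= 0 or not a:
        return []  # nothing to copy
    size = len(a) * n
    out = [None] * size  # stands in for PyList_New(size): allocate once, exact size
    p = 0
    for _ in range(n):
        for item in a:
            out[p] = item  # store a reference to the same object, no copy
            p += 1
    return out


row = [None]
grid = list_repeat_model([row], 3)
print(grid == [row] * 3)            # True
print(all(r is row for r in grid))  # True: every slot refers to the same inner list
```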

Upvotes: 7

RufusVS

Reputation: 4127

I noticed the different techniques were generating different structures (as I noted in a comment). This rewrite yields the same structures, but as @juanpa.arrivillaga pointed out, arr1 = [[None] * n] * n actually gives you n references to a single inner list, which will show up as soon as you start assigning values to array elements.

import time
from pprint import pprint

n = 999 # for time test

# n = 5 # for structure printout test.

t0 = time.time()
arr1 = [[None] * n] * n
t1 = time.time()
print(t1 - t0)

t0 = time.time()
arr2 = [[None] * n for _ in range(n)]
t1 = time.time()
print(t1 - t0)

t2 = time.time()
arr3 = [[None for _ in range(n)] for _ in range(n)]
t3 = time.time()
print(t3 - t2)

if n < 20:
    print(len(arr1), len(arr1[0]))
    pprint(arr1)
    print(len(arr2), len(arr2[0]))
    pprint(arr2)
    print(len(arr3), len(arr3[0]))
    pprint(arr3)
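The aliasing mentioned above is easy to see with a small n: assigning to one element of the [[None] * n] * n version changes every row, while the comprehension version keeps the rows independent. show_aliasing is just an illustrative name.

```python
def show_aliasing(n: int = 3):
    aliased = [[None] * n] * n                    # n references to ONE inner list
    independent = [[None] * n for _ in range(n)]  # n distinct inner lists

    aliased[0][0] = "x"
    independent[0][0] = "x"
    return aliased, independent


aliased, independent = show_aliasing()
print(aliased)      # [['x', None, None], ['x', None, None], ['x', None, None]]
print(independent)  # [['x', None, None], [None, None, None], [None, None, None]]
```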

Upvotes: -1
