N. Gast

Reputation: 192

Numba and numpy array allocation: why is it so slow?

I recently played with Cython and Numba to accelerate small pieces of a Python code that does numerical simulation. At first, developing with Numba seems easier. Yet, I found it difficult to understand when Numba will provide better performance and when it will not.

One example of an unexpected performance drop is when I use the function np.zeros() to allocate a big array in a compiled function. For example, consider the three function definitions:

import numpy as np 
from numba import jit 

def pure_python(n):
    mat = np.zeros((n,n), dtype=np.double)
    # do something
    return mat.reshape((n**2))

@jit(nopython=True)
def pure_numba(n):
    mat = np.zeros((n,n), dtype=np.double)
    # do something
    return mat.reshape((n**2))

def mixed_numba1(n):
    return mixed_numba2(np.zeros((n,n)))

@jit(nopython=True)
def mixed_numba2(array):
    n = len(array)
    # do something
    return array.reshape((n**2))

# To compile 
pure_numba(10)
mixed_numba1(10)

Since the # do something part is empty, I do not expect the pure_numba function to be faster. Yet, I was not expecting such a performance drop:

n=10000
%timeit x = pure_python(n)
%timeit x = pure_numba(n)
%timeit x = mixed_numba1(n)

I obtain (Python 3.7.7, Numba 0.48.0 on a Mac):

4.96 µs ± 65.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
344 ms ± 7.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.8 µs ± 30.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Here, the Numba code is much slower when I use np.zeros() inside the compiled function. It works normally when np.zeros() is called outside the compiled function.

Am I doing something wrong here or should I always allocate big arrays like these outside functions that are compiled by numba?

Update

This seems related to lazy initialization of the matrices by np.zeros((n,n)) when n is large enough (see Performance of zeros function in Numpy).

for n in [1000, 2000, 5000]:
    print('n=',n)
    %timeit x = pure_python(n)
    %timeit x = pure_numba(n)
    %timeit x = mixed_numba1(n)

gives me:

n = 1000
468 µs ± 15.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
296 µs ± 6.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
300 µs ± 2.26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
n = 2000
4.79 ms ± 182 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.45 ms ± 36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.54 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
n = 5000
270 µs ± 4.66 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
104 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
119 µs ± 1.24 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Upvotes: 6

Views: 1935

Answers (1)

Jacques Gaudin

Reputation: 16958

tl;dr NumPy uses C memory functions whereas Numba must assign the zeros itself

I wrote a script to plot the time it takes for several options to complete and it appears that Numba has a severe drop in performance when the size of the np.zeros array reaches 2048*2048*8 = 32 MB on my machine as shown in the diagram below.

Numba's implementation of np.zeros is just as fast as creating an empty array and filling it with zeros by iterating over the dimensions of the array (this is the green "Numba nested loop" curve in the diagram). This can be double-checked by setting the NUMBA_DUMP_IR environment variable before running the script (see below); the dump is not much different from the one for numba_loop.
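As an aside, if you do not want to rely on the environment variable, Numba's dispatcher objects also expose inspection helpers. A minimal sketch (note that inspect_llvm/inspect_asm show the LLVM IR and machine code rather than the Numba IR dumped below):

import numpy as np
from numba import jit

@jit(nopython=True)
def numba_zeros(n):
    mat = np.zeros((n, n), dtype=np.double)
    return mat.reshape((n**2))

numba_zeros(10)  # trigger compilation for an integer argument

# One entry per compiled signature; the listings are long, so print only the head.
for sig, llvm_ir in numba_zeros.inspect_llvm().items():
    print(sig)
    print(llvm_ir[:500])

# numba_zeros.inspect_asm() works the same way for the generated assembly.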

Interestingly, np.zeros gets a little boost past the 32 MB threshold.

My best guess, although I am far from an expert, is that the 32 MB limit is an OS or hardware bottleneck related to the amount of data that can fit in a cache for the same process. If it is exceeded, moving data in and out of the cache to operate on it becomes very time consuming.

By contrast, NumPy uses calloc to get a memory segment, with a promise that the data will be filled with zeros when it is actually accessed.
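If that is the case, the cost of NumPy's np.zeros should largely be deferred until the array is actually written to, whereas Numba pays it up front. A minimal sketch to check the NumPy side of this (not part of the benchmark above; n is chosen to be well past the 32 MB threshold):

import numpy as np
from timeit import timeit

n = 5000  # 5000 * 5000 * 8 bytes = 200 MB, well past the 32 MB threshold

# Allocation only: with calloc, the OS can hand back zeroed pages lazily.
t_alloc = timeit(lambda: np.zeros((n, n)), number=20) / 20

# Allocation plus a full write: every page now has to be materialised.
def alloc_and_fill():
    mat = np.zeros((n, n))
    mat[:] = 1.0
    return mat

t_fill = timeit(alloc_and_fill, number=20) / 20

print(f'np.zeros only        : {t_alloc * 1e3:.2f} ms')
print(f'np.zeros + full write: {t_fill * 1e3:.2f} ms')

If lazy zeroing is at play, the first timing should be much smaller than the second.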

This is how far I got and I realise it's only half an answer but maybe someone more knowledgeable can shed some light on what is actually going on.

[Graph of time deltas for different options]

Numba IR dump:

---------------------------IR DUMP: pure_numba_zeros----------------------------
label 0:
    n = arg(0, name=n)                       ['n']
    $2load_global.0 = global(np: <module 'numpy' from '/lib/python3.8/site-packages/numpy/__init__.py'>) ['$2load_global.0']
    $4load_attr.1 = getattr(value=$2load_global.0, attr=zeros) ['$2load_global.0', '$4load_attr.1']
    del $2load_global.0                      []
    $10build_tuple.4 = build_tuple(items=[Var(n, script.py:15), Var(n, script.py:15)]) ['$10build_tuple.4', 'n', 'n']
    $12load_global.5 = global(np: <module 'numpy' from '/lib/python3.8/site-packages/numpy/__init__.py'>) ['$12load_global.5']
    $14load_attr.6 = getattr(value=$12load_global.5, attr=double) ['$12load_global.5', '$14load_attr.6']
    del $12load_global.5                     []
    $18call_function_kw.8 = call $4load_attr.1($10build_tuple.4, func=$4load_attr.1, args=[Var($10build_tuple.4, script.py:15)], kws=[('dtype', Var($14load_attr.6, script.py:15))], vararg=None) ['$10build_tuple.4', '$14load_attr.6', '$18call_function_kw.8', '$4load_attr.1']
    del $4load_attr.1                        []
    del $14load_attr.6                       []
    del $10build_tuple.4                     []
    mat = $18call_function_kw.8              ['$18call_function_kw.8', 'mat']
    del $18call_function_kw.8                []
    $24load_method.10 = getattr(value=mat, attr=reshape) ['$24load_method.10', 'mat']
    del mat                                  []
    $const28.12 = const(int, 2)              ['$const28.12']
    $30binary_power.13 = n ** $const28.12    ['$30binary_power.13', '$const28.12', 'n']
    del n                                    []
    del $const28.12                          []
    $32call_method.14 = call $24load_method.10($30binary_power.13, func=$24load_method.10, args=[Var($30binary_power.13, script.py:16)], kws=(), vararg=None) ['$24load_method.10', '$30binary_power.13', '$32call_method.14']
    del $30binary_power.13                   []
    del $24load_method.10                    []
    $34return_value.15 = cast(value=$32call_method.14) ['$32call_method.14', '$34return_value.15']
    del $32call_method.14                    []
    return $34return_value.15                ['$34return_value.15']

The script to produce the diagram:

import numpy as np
from numba import jit
from time import time
import os
import matplotlib.pyplot as plt

os.environ['NUMBA_DUMP_IR'] = '1'

def numpy_zeros(n):
    mat = np.zeros((n,n), dtype=np.double)
    return mat.reshape((n**2))

@jit(nopython=True)
def numba_zeros(n):
    mat = np.zeros((n,n), dtype=np.double)
    return mat.reshape((n**2))

@jit(nopython=True)
def numba_loop(n):
    mat = np.empty((n * 2,n), dtype=np.float32)
    for i in range(mat.shape[0]):
        for j in range(mat.shape[1]):
            mat[i, j] = 0.
    return mat.reshape((2 * n**2))

# To compile
numba_zeros(10)
numba_loop(10)

os.environ['NUMBA_DUMP_IR'] = '0'

max_n = 4100
time_deltas = {
    'numpy_zeros': [],
    'numba_zeros': [],
    'numba_loop': [],
}
call_count = 10
for n in range(0, max_n, 10):
    for f in (numpy_zeros, numba_zeros, numba_loop):
        start = time()
        for i in range(call_count):
            x = f(n)
        delta = time() - start
        time_deltas[f.__name__].append(delta / call_count)
        print(f'{f.__name__:25} n = {n}: {delta}')
    print()

size = np.arange(0, max_n, 10) ** 2 * 8 / 1024 ** 2
fig, ax = plt.subplots()
plt.xticks(np.arange(0, size[-1], 16))
plt.axvline(x=32, color='gray', lw=0.5)
ax.plot(size, time_deltas['numpy_zeros'], label='Numpy zeros (calloc)')
ax.plot(size, time_deltas['numba_zeros'], label='Numba zeros')
ax.plot(size, time_deltas['numba_loop'], label='Numba nested loop')
ax.set_xlabel('Size of array in MB')
ax.set_ylabel(r'Mean $\Delta$t in s')
plt.legend(loc='upper left')
plt.show()

Upvotes: 8
