Yotam Vaknin

Reputation: 676

Fast multiplication of a series of matrices

What is the fastest possible way to run:

    reduce(lambda x,y : x@y, ls)

in python?

for a list of matrices ls. I don't have an Nvidia GPU, but I do have a lot of CPU cores to work with. I thought I could make the process run in parallel (splitting it into log(n) passes of pairwise products), but it seems that for small (1000x1000) matrices this is actually worse. Here is the code I tried:

from multiprocessing import Pool
import numpy as np
from itertools import zip_longest

def matmul(x):
    # x is a pair of matrices; the second entry is None when the list length was odd
    if x[1] is None:
        return x[0]
    # Multiply in list order so the result matches reduce(lambda a, b: a @ b, ls)
    return x[0] @ x[1]

def fast_mul(ls):
    # Multiply adjacent pairs in parallel; the list halves on every pass,
    # so there are about log2(n) passes in total
    while True:
        n = len(ls)
        if n == 0:
            raise Exception("Splitting Error")
        if n == 1:
            return ls[0]
        if n == 2:
            return ls[0] @ ls[1]

        with Pool(processes=(n // 2 + 1)) as pool:
            ls = pool.map(matmul, list(zip_longest(*[iter(ls)] * 2)))

Upvotes: 4

Views: 400

Answers (2)

javidcf

Reputation: 59681

EDIT: Threw in yet another possible function

EDIT: I added the results with np.linalg.multi_dot, expecting it to be faster than the rest, but it is actually much slower somehow. I suppose it is designed with a different kind of use case in mind.


I'm not sure you will be able to get much faster than that. Here are a few different implementations of the reduction for the case where the data is a 3D array of square matrices:

from multiprocessing import Pool
from functools import reduce
import numpy as np
import numba as nb

def matmul_n_naive(data):
    return reduce(np.matmul, data)

# If you don't care about modifying data pass copy=False
def matmul_n_binary(data, copy=True):
    if len(data) < 1:
        raise ValueError
    data = np.array(data, copy=copy)
    n, r, c = data.shape
    if n == 1:
        return np.array(data[0])
    s = 1
    while (n + s - 1) // s > 1:
        # Multiply adjacent pairs in place; after each pass the partial
        # products sit at indices that are multiples of 2 * s
        a = data[:n - s:2 * s]
        b = data[s:n:2 * s]
        np.matmul(a, b, out=a)
        s *= 2
    return np.array(a[0])

def matmul_n_pool(data):
    if len(data) < 1:
        raise ValueError
    lst = data
    with Pool() as pool:
        while len(lst) > 1:
            lst_next = pool.starmap(np.matmul, zip(lst[::2], lst[1::2]))
            if len(lst) % 2 != 0:
                lst_next.append(lst[-1])
            lst = lst_next
    return lst[0]

@nb.njit(parallel=False)
def matmul_n_numba_nopar(data):
    res = np.eye(data.shape[1], data.shape[2], dtype=data.dtype)
    for i in nb.prange(len(data)):
        res = res @ data[i]
    return res

@nb.njit(parallel=True)
def matmul_n_numba_par(data):
    res = np.eye(data.shape[1], data.shape[2], dtype=data.dtype)
    for i in nb.prange(len(data)):  # Numba knows how to do parallel reductions correctly
        res = res @ data[i]
    return res

def matmul_n_multidot(data):
    return np.linalg.multi_dot(data)

And a test:

# Test
import numpy as np

np.random.seed(0)
a = np.random.rand(10, 100, 100) * 2 - 1
b1 = matmul_n_naive(a)
b2 = matmul_n_binary(a)
b3 = matmul_n_pool(a)
b4 = matmul_n_numba_nopar(a)
b5 = matmul_n_numba_par(a)
b6 = matmul_n_multidot(a)
print(np.allclose(b1, b2))
# True
print(np.allclose(b1, b3))
# True
print(np.allclose(b1, b4))
# True
print(np.allclose(b1, b5))
# True
print(np.allclose(b1, b6))
# True

Here are some benchmarks. There seems to be no consistent winner, but the "naive" solution is pretty good all around; the binary and Numba versions vary, the process pool is not really competitive, and np.linalg.multi_dot does not seem to be very advantageous with square matrices.

import numpy as np

# 10 matrices 1000x1000
np.random.seed(0)
a = np.random.rand(10, 1000, 1000) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 121 ms ± 6.09 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit matmul_n_binary(a)
# 165 ms ± 3.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit matmul_n_numba_nopar(a)
# 108 ms ± 510 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit matmul_n_numba_par(a)
# 244 ms ± 7.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit matmul_n_multidot(a)
# 132 ms ± 2.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# 200 matrices 100x100
np.random.seed(0)
a = np.random.rand(200, 100, 100) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 4.4 ms ± 226 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_binary(a)
# 13.4 ms ± 299 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_numba_nopar(a)
# 9.51 ms ± 126 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_numba_par(a)
# 4.93 ms ± 146 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_multidot(a)
# 1.14 s ± 22.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# 300 matrices 10x10
np.random.seed(0)
a = np.random.rand(300, 10, 10) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 526 µs ± 953 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_binary(a)
# 152 µs ± 508 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit matmul_n_pool(a)
# 610 ms ± 5.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit matmul_n_numba_nopar(a)
# 239 µs ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_numba_par(a)
# 175 µs ± 422 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit matmul_n_multidot(a)
# 3.68 s ± 87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# 1000 matrices 10x10
np.random.seed(0)
a = np.random.rand(1000, 10, 10) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 1.56 ms ± 4.49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_binary(a)
# 392 µs ± 790 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_pool(a)
# 727 ms ± 12.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit matmul_n_numba_nopar(a)
# 589 µs ± 356 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_numba_par(a)
# 451 µs ± 1.68 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_multidot(a)
# Never finished...

Upvotes: 2

Mad Physicist

Reputation: 114230

There is a function to do this: np.linalg.multi_dot, which is supposed to pick the evaluation order (i.e. the parenthesization of the chain) that minimizes the total cost:

np.linalg.multi_dot(ls)

In fact the docs say something very close to your original phrasing:

Think of multi_dot as:

def multi_dot(arrays): return functools.reduce(np.dot, arrays)
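
To see why the evaluation order matters, here is a toy illustration (the shapes below are made up purely for the example): with rectangular matrices, different parenthesizations of the same chain can require very different amounts of arithmetic, and multi_dot is meant to pick the cheap one.

import numpy as np

# Hypothetical shapes chosen only to make the point obvious
A = np.random.rand(10, 30)
B = np.random.rand(30, 5)
C = np.random.rand(5, 60)

# (A @ B) @ C costs 10*30*5 + 10*5*60 = 4,500 scalar multiplications,
# whereas A @ (B @ C) costs 30*5*60 + 10*30*60 = 27,000.
# multi_dot chooses the parenthesization for you:
result = np.linalg.multi_dot([A, B, C])
print(result.shape)  # (10, 60)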

You could also try np.einsum, which will allow you to multiply up to 25 matrices (the index string below is built from the 26 lowercase letters):

from string import ascii_lowercase

ls = [...]
index = ','.join(ascii_lowercase[x:x + 2] for x in range(len(ls)))
index += f'->{index[0]}{index[-1]}'
np.einsum(index, *ls)
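
As a tiny hypothetical illustration of the subscripts this builds, with three matrices the string comes out as 'ab,bc,cd->ad', which contracts them exactly like a left-to-right product:

import numpy as np
from string import ascii_lowercase

# Small made-up example: three compatible matrices
ls = [np.random.rand(4, 4) for _ in range(3)]
index = ','.join(ascii_lowercase[x:x + 2] for x in range(len(ls)))
index += f'->{index[0]}{index[-1]}'
print(index)  # ab,bc,cd->ad
out = np.einsum(index, *ls)  # same result as ls[0] @ ls[1] @ ls[2]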

Timing

Simple case:

ls = np.random.rand(100, 1000, 1000) - 0.5

%timeit reduce(lambda x, y : x @ y, ls)
4.3 s ± 76.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
4.35 s ± 84.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
4.86 s ± 68.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
5.24 s ± 66.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

More complicated case:

ls = [x.T if i % 2 else x for i, x in enumerate(np.random.rand(100, 2000, 500) - 0.5)]

%timeit reduce(lambda x, y : x @ y, ls)
7.94 s ± 96.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
7.91 s ± 33.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
9.38 s ± 111 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
2.03 s ± 52.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Notice that the up-front work done by multi_dot actually hurts in the straightforward case (and, more surprisingly, the lambda runs slightly faster than passing np.matmul directly), but it saves about 75% of the time in the less straightforward case.

So, just for completeness, here is a case where the matrices are closer to square:

ls = [x.T if i % 2 else x for i, x in enumerate(np.random.rand(100, 400, 300) - 0.5)]

%timeit reduce(lambda x, y : x @ y, ls)
245 ms ± 8.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
245 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
284 ms ± 12.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
638 ms ± 12.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

So really it seems that for most general cases, your original reduce call is about as good as you are going to get. My only suggestion would be to use operator.matmul instead of the lambda.
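
A minimal sketch of that last suggestion (assuming ls is your list of compatible matrices):

from functools import reduce
from operator import matmul

# Equivalent to reduce(lambda x, y: x @ y, ls), without the per-call lambda overhead
result = reduce(matmul, ls)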

Upvotes: 2
