fake823

Any chance of making this faster? (numpy.einsum)

I'm trying to multiply three arrays (A x B x A) with the dimensions (19000, 3) x (19000, 3, 3) x (19000, 3), so that I end up with a 1D array of size 19000. In other words, for each of the 19000 rows s I want the scalar A[s] · B[s] · A[s], so I only want to multiply along the last one/two dimensions.

I've got it working with np.einsum() but I'm wondering if there is any way of making this faster, as this is the bottleneck of my whole code.

np.einsum('...i,...ij,...j', A, B, A)
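For reference, this einsum call computes one quadratic form per row: res[s] = sum over i and j of A[s, i] * B[s, i, j] * A[s, j]. The following is a minimal sketch that checks this against an explicit Python loop on small random data; the helper name `quadform_loop` and the array sizes are illustrative choices, not part of the question.

```python
import numpy as np

def quadform_loop(A, B):
    """Slow reference version: res[s] = A[s] @ B[s] @ A[s] for every row s."""
    n, d = A.shape
    res = np.empty(n, dtype=A.dtype)
    for s in range(n):
        acc = 0.0
        for i in range(d):
            for j in range(d):
                acc += A[s, i] * B[s, i, j] * A[s, j]
        res[s] = acc
    return res

rng = np.random.default_rng(0)
A = rng.random((100, 3))
B = rng.random((100, 3, 3))

# In implicit mode, i and j each appear twice, so they are summed;
# only the broadcast (ellipsis) axis survives in the output.
res_einsum = np.einsum('...i,...ij,...j', A, B, A)
res_loop = quadform_loop(A, B)
print(res_einsum.shape)                   # (100,)
print(np.allclose(res_einsum, res_loop))  # True
```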

I've already tried it with two separate np.einsum() calls, but that gave the same performance:

np.einsum('...i, ...i', np.einsum('...i,...ij', A, B), A)
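The two-step version first materializes an (n, 3) intermediate C[s, j] = sum over i of A[s, i] * B[s, i, j] and then takes a row-wise dot product of C with A. A small sketch (random data, sizes chosen arbitrarily) showing that both steps reproduce the one-step result, and that the second step is just an elementwise product followed by a sum over the last axis:

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.random((100, 3))
B = rng.random((100, 3, 3))

# Step 1: C[s, j] = sum_i A[s, i] * B[s, i, j]  -> shape (n, 3)
C = np.einsum('...i,...ij', A, B)
# Step 2: res[s] = sum_j C[s, j] * A[s, j]      -> shape (n,)
res_two_step = np.einsum('...i,...i', C, A)
# The same second step with plain broadcasting and a sum over the last axis
res_sum = (C * A).sum(axis=1)

res_one_step = np.einsum('...i,...ij,...j', A, B, A)
print(C.shape)  # (100, 3)
print(np.allclose(res_two_step, res_one_step), np.allclose(res_sum, res_one_step))  # True True
```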

I've also tried the @ operator with some additional axes, but that didn't make it faster either:

(A[:, None]@B@A[...,None]).squeeze()
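The @ version works because matmul treats every axis except the last two as a batch axis. Each A[s] becomes a 1x3 row on the left and a 3x1 column on the right, so every batch produces a 1x1 result. The sketch below spells out the shapes. Indexing with [:, 0, 0] instead of .squeeze() is a small deviation from the question: .squeeze() would also drop the batch axis if n happened to be 1.

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.random((100, 3))
B = rng.random((100, 3, 3))

row = A[:, None, :]       # (n, 1, 3): each A[s] as a 1x3 row vector
col = A[..., None]        # (n, 3, 1): each A[s] as a 3x1 column vector
stacked = row @ B @ col   # matmul broadcasts over the leading n axis -> (n, 1, 1)
res_matmul = stacked[:, 0, 0]  # like .squeeze() here, but keeps the n axis even if n == 1

res_einsum = np.einsum('...i,...ij,...j', A, B, A)
print(stacked.shape)                        # (100, 1, 1)
print(np.allclose(res_matmul, res_einsum))  # True
```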

I've tried to get it working with np.inner(), np.dot(), np.tensordot() and np.vdot(), but these never gave me the same results, so I couldn't compare them.
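One likely reason np.dot() and np.tensordot() give different results: they do not treat the leading axis as a batch axis. They contract over i for every pair of rows (s, t), so the result has shape (n, n, 3), and only its diagonal s == t is the batched product wanted here. A sketch with a tiny n to show this; at n = 19000 the full array would hold about 1.08e9 float64 values (roughly 8.7 GB), which is why these functions are not a practical route for this problem.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 5
A = rng.random((n, 3))
B = rng.random((n, 3, 3))

# np.dot / np.tensordot contract over i for EVERY pair (s, t) of rows
D = np.dot(A, B)                          # D[s, t, j] = sum_i A[s, i] * B[t, i, j]
T = np.tensordot(A, B, axes=([1], [1]))   # same contraction, same shape
print(D.shape, T.shape)  # (5, 5, 3) (5, 5, 3)

# Only the diagonal s == t corresponds to the batched product
idx = np.arange(n)
C_diag = D[idx, idx]                       # shape (n, 3)
C_batched = np.einsum('si,sij->sj', A, B)  # shape (n, 3)
print(np.allclose(C_diag, C_batched))  # True
```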

Any other ideas? Is there any way I could get a better performance?

I've already had a quick look at Numba, but as Numba doesn't support np.einsum() and many other NumPy functions, I would have to rewrite a lot of code.

Answers (1)

max9111

You could use Numba

It is always a good idea to start by looking at what np.einsum does. With optimize="optimal" it is usually very good at finding a contraction order with fewer FLOPs. In this case only a minor optimization is possible (a theoretical speedup of 1.125, see the output below) and the intermediate array is relatively large, so I will stick to the naive version. It should also be mentioned that contractions over very small (fixed?) dimensions are a rather special case. This is also why it is quite easy to outperform np.einsum here: if a compiler knows that a loop consists of only 3 iterations, it can unroll it, etc.

import numpy as np

A=np.random.rand(19000, 3)
B=np.random.rand(19000, 3, 3)

print(np.einsum_path('...i,...ij,...j', A, B, A,optimize="optimal")[1])

"""
  Complete contraction:  si,sij,sj->s
         Naive scaling:  3
     Optimized scaling:  3
      Naive FLOP count:  5.130e+05
  Optimized FLOP count:  4.560e+05
   Theoretical speedup:  1.125
  Largest intermediate:  5.700e+04 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   3                  sij,si->js                                 sj,js->s
   2                    js,sj->s                                     s->s

"""

Numba implementation

import numba as nb

#si,sij,sj->s
@nb.njit(fastmath=True,parallel=True,cache=True)
def nb_einsum(A,B):
    #Check the inputs at the beginning.
    #I assume that the asserted shapes are always constant;
    #this makes it easier for the compiler to optimize (e.g. unroll the loops)
    assert A.shape[1]==3
    assert B.shape[1]==3
    assert B.shape[2]==3

    #allocate output
    res=np.empty(A.shape[0],dtype=A.dtype)

    for s in nb.prange(A.shape[0]):
        #Accumulate in a local scalar and write res[s] only once;
        #this is also important for performance
        acc=0.0
        for i in range(3):
            for j in range(3):
                acc+=A[s,i]*B[s,i,j]*A[s,j]
        res[s]=acc
    return res
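The Numba version relies on the compiler unrolling the fixed 3x3 loops. For comparison, the same unrolling can be written by hand in plain NumPy by spelling out the nine terms as whole-column operations. This is a swapped-in technique, not part of the Numba approach above; the name `quadform_unrolled` is hypothetical, and it has not been timed here, so whether it beats np.einsum on a given machine is an open question (it still creates several temporary 19000-element arrays).

```python
import numpy as np

def quadform_unrolled(A, B):
    """res[s] = A[s] @ B[s] @ A[s] for 3-component rows, with the i/j loops
    written out by hand as nine whole-column NumPy operations."""
    assert A.shape[1] == 3 and B.shape[1:] == (3, 3)
    a0, a1, a2 = A[:, 0], A[:, 1], A[:, 2]
    # t_i = (B[s] @ A[s])[i] for i = 0, 1, 2
    t0 = B[:, 0, 0] * a0 + B[:, 0, 1] * a1 + B[:, 0, 2] * a2
    t1 = B[:, 1, 0] * a0 + B[:, 1, 1] * a1 + B[:, 1, 2] * a2
    t2 = B[:, 2, 0] * a0 + B[:, 2, 1] * a1 + B[:, 2, 2] * a2
    # res[s] = sum_i A[s, i] * t_i
    return a0 * t0 + a1 * t1 + a2 * t2

rng = np.random.default_rng(4)
A = rng.random((19000, 3))
B = rng.random((19000, 3, 3))
print(np.allclose(quadform_unrolled(A, B), np.einsum('...i,...ij,...j', A, B, A)))  # True
```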

Timings

#Warm-up: the first call is always slower
#(due to compilation or loading the cached function)
res=nb_einsum(A,B)
%timeit nb_einsum(A,B)
#43.2 µs ± 1.22 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit np.einsum('...i,...ij,...j', A, B, A,optimize=True)
#450 µs ± 8.28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.einsum('...i,...ij,...j', A, B, A)
#977 µs ± 4.14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
np.allclose(np.einsum('...i,...ij,...j', A, B, A,optimize=True),nb_einsum(A,B))
#True
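%timeit is an IPython magic. Outside IPython, a comparable measurement can be sketched with the standard timeit module. In this harness the candidate names, repeat counts and random data are arbitrary choices; each variant is checked against the plain einsum result before it is timed. nb_einsum from above could be added as another entry after a warm-up call. No numbers are quoted here because they depend on the machine.

```python
import timeit
import numpy as np

rng = np.random.default_rng(6)
A = rng.random((19000, 3))
B = rng.random((19000, 3, 3))

candidates = {
    "einsum": lambda: np.einsum('...i,...ij,...j', A, B, A),
    "einsum optimize=True": lambda: np.einsum('...i,...ij,...j', A, B, A, optimize=True),
    "matmul": lambda: (A[:, None, :] @ B @ A[..., None])[:, 0, 0],
}

reference = candidates["einsum"]()
timings = {}
for name, fn in candidates.items():
    # Check correctness before timing
    assert np.allclose(fn(), reference), name
    # Best of 3 repeats of 20 calls each, converted to seconds per call
    best = min(timeit.repeat(fn, number=20, repeat=3)) / 20
    timings[name] = best
    print(f"{name:>22}: {best * 1e6:8.1f} µs per call")
```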
