Reputation: 108
I'm trying to multiply three arrays (A x B x A) with the dimensions (19000, 3) x (19000, 3, 3) x (19000, 3), so that at the end I get a 1d-array of size (19000). In other words, I want to multiply only along the last one/two dimensions, i.e. compute res[s] = sum_ij A[s,i] * B[s,i,j] * A[s,j] for every row s.
I've got it working with np.einsum(), but I'm wondering if there is any way of making this faster, as this is the bottleneck of my whole code.
np.einsum('...i,...ij,...j', A, B, A)
I've already tried it with two separate np.einsum() calls, but that gave me the same performance:
np.einsum('...i, ...i', np.einsum('...i,...ij', A, B), A)
I've also tried the @ operator with some additional axes added, but that didn't make it faster either:
(A[:, None]@B@A[...,None]).squeeze()
I've tried to get it working with np.inner(), np.dot(), np.tensordot() and np.vdot(), but these never gave me the same results, so I couldn't compare them.
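For reference, a plain broadcasting formulation that should reproduce the einsum result (useful as a correctness check; it materializes the full (19000, 3, 3) product, so I wouldn't expect it to be faster):
# Elementwise products via broadcasting, then a sum over the last two axes;
# element [s, i, j] of the product is A[s, i] * B[s, i, j] * A[s, j]
res = (A[:, :, None] * B * A[:, None, :]).sum(axis=(1, 2))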
Any other ideas? Is there any way I could get a better performance?
I've already had a quick look at Numba, but as Numba doesn't support np.einsum() and many other NumPy functions, I would have to rewrite a lot of code.
Upvotes: 3
Views: 1317
Reputation: 6482
It is always a good idea to start by looking at what np.einsum actually does. With optimize="optimal" it is usually quite good at finding a contraction order with fewer FLOPs. In this case only a minor optimization is possible, and the intermediate array is relatively large, so I will stick to the naive version. It should also be mentioned that contractions over very small (fixed?) dimensions are a quite special case. This is also a reason why it is quite easy to outperform np.einsum here (unrolling etc., which a compiler does if it knows that a loop consists of only 3 elements).
import numpy as np

A = np.random.rand(19000, 3)
B = np.random.rand(19000, 3, 3)

print(np.einsum_path('...i,...ij,...j', A, B, A, optimize="optimal")[1])
"""
Complete contraction: si,sij,sj->s
Naive scaling: 3
Optimized scaling: 3
Naive FLOP count: 5.130e+05
Optimized FLOP count: 4.560e+05
Theoretical speedup: 1.125
Largest intermediate: 5.700e+04 elements
--------------------------------------------------------------------------
scaling current remaining
--------------------------------------------------------------------------
3 sij,si->js sj,js->s
2 js,sj->s s->s
"""
Numba implementation
import numba as nb

# si,sij,sj -> s
@nb.njit(fastmath=True, parallel=True, cache=True)
def nb_einsum(A, B):
    # Check the inputs at the beginning.
    # I assume that the asserted shapes are always constant;
    # this makes it easier for the compiler to optimize.
    assert A.shape[1] == 3
    assert B.shape[1] == 3
    assert B.shape[2] == 3

    # Allocate the output
    res = np.empty(A.shape[0], dtype=A.dtype)
    for s in nb.prange(A.shape[0]):
        # Accumulating into a local scalar like this is also important for performance
        acc = 0
        for i in range(3):
            for j in range(3):
                acc += A[s, i] * B[s, i, j] * A[s, j]
        res[s] = acc
    return res
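Depending on the array size, the thread start-up overhead of parallel=True can eat into the gains, so a serial variant of the same kernel (just a sketch, worth benchmarking for smaller inputs) is easy to compare against:
# Single-threaded variant of the same kernel; for small arrays the
# threading overhead of the parallel version may dominate
@nb.njit(fastmath=True, cache=True)
def nb_einsum_serial(A, B):
    assert A.shape[1] == 3
    assert B.shape[1] == 3
    assert B.shape[2] == 3
    res = np.empty(A.shape[0], dtype=A.dtype)
    for s in range(A.shape[0]):
        acc = 0
        for i in range(3):
            for j in range(3):
                acc += A[s, i] * B[s, i, j] * A[s, j]
        res[s] = acc
    return res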
Timings
# Warmup: the first call is always slower
# (due to compilation or loading the cached function)
res = nb_einsum(A, B)

%timeit nb_einsum(A, B)
#43.2 µs ± 1.22 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit np.einsum('...i,...ij,...j', A, B, A, optimize=True)
#450 µs ± 8.28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.einsum('...i,...ij,...j', A, B, A)
#977 µs ± 4.14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

np.allclose(np.einsum('...i,...ij,...j', A, B, A, optimize=True), nb_einsum(A, B))
#True
Upvotes: 2