Reputation: 725
I'm looking for an efficient way to multiply a list of matrices in NumPy. I have a stack of matrices like this:
import numpy as np
a = np.random.randn(1000, 4, 4)
I want to matrix-multiply along the long axis, so the result is a 4x4 matrix. So clearly I can do:
res = np.identity(4)
for ai in a:
    res = np.matmul(res, ai)
But this is super-slow. Is there a faster way (perhaps using einsum or some other function that I don't fully understand yet)?
Upvotes: 4
Views: 6180
Reputation: 36859
A solution that requires log_2(n) loop iterations for stacks whose size is a power of 2 could be
while len(a) > 1:
    a = np.matmul(a[::2, ...], a[1::2, ...])
which essentially iteratively multiplies two neighbouring matrices together until there is only one matrix left, doing half of the remaining multiplications per iteration.
res = A * B * C * D * ... # 1024 remaining multiplications
becomes
res = (A * B) * (C * D) * ... # 512 remaining multiplications
becomes
res = ((A * B) * (C * D)) * ... # 256 remaining multiplications
etc.
For non-powers of 2 you can do this for the first 2^n matrices and use your algorithm for the remaining matrices.
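A minimal sketch of one way to handle arbitrary stack sizes. This is a variation on the suggestion above, not the exact split described: whenever the stack has an odd length, the last matrix is set aside and re-appended after the pairwise products, so the left-to-right order is preserved. The function name pairwise_chain_matmul is made up for this illustration.

```python
import numpy as np

def pairwise_chain_matmul(stack):
    """Multiply a (n, k, k) stack left to right using pairwise reduction.

    Hypothetical helper for illustration; assumes n >= 1.
    """
    stack = np.asarray(stack)
    while len(stack) > 1:
        if len(stack) % 2:
            # Odd length: multiply neighbours among the first n-1 matrices
            # and carry the last matrix over unchanged to the next round.
            pairs = np.matmul(stack[:-1:2], stack[1:-1:2])
            stack = np.concatenate([pairs, stack[-1:]])
        else:
            # Even length: multiply each matrix with its right neighbour.
            stack = np.matmul(stack[::2], stack[1::2])
    return stack[0]

a = np.random.randn(1000, 4, 4)
res = pairwise_chain_matmul(a)  # shape (4, 4)
```

Because floating-point multiplication is grouped differently than in the plain loop, the result can differ from the loop's result in the last few digits. An empty stack raises an IndexError in this sketch.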
Upvotes: 4
Reputation: 231738
np.linalg.multi_dot does this sort of chaining.
In [119]: a = np.random.randn(5, 4, 4)
In [120]: res = np.identity(4)
In [121]: for ai in a: res = np.matmul(res, ai)
In [122]: res
Out[122]:
array([[ -1.04341835, -1.22015464, 9.21459712, 0.97214725],
[ -0.13652679, 0.61012689, -0.07325689, -0.17834132],
[ -2.45684401, -1.76347514, 12.41094524, 1.00411347],
[ -8.36738671, -6.5010718 , 15.32489832, 3.62426123]])
In [123]: np.linalg.multi_dot(a)
Out[123]:
array([[ -1.04341835, -1.22015464, 9.21459712, 0.97214725],
[ -0.13652679, 0.61012689, -0.07325689, -0.17834132],
[ -2.45684401, -1.76347514, 12.41094524, 1.00411347],
[ -8.36738671, -6.5010718 , 15.32489832, 3.62426123]])
But it is slower: 92.3 µs per loop for multi_dot versus 22.2 µs per loop for the explicit loop. And for your 1000 item case, the test timing is still running.
After figuring out some 'optimal order', multi_dot does a recursive dot.
def _multi_dot(arrays, order, i, j):
    """Actually do the multiplication with the given order."""
    if i == j:
        return arrays[i]
    else:
        return dot(_multi_dot(arrays, order, i, order[i, j]),
                   _multi_dot(arrays, order, order[i, j] + 1, j))
In the 1000 item case this hits a recursion depth error.
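For comparison, a non-recursive way to write the same left-to-right chain is functools.reduce with np.matmul. This is only a sketch: it is equivalent to the explicit loop from the question (so it is not expected to be faster), but it does not depend on the recursion limit. The function name chain_matmul_reduce is made up for this example.

```python
from functools import reduce

import numpy as np

def chain_matmul_reduce(stack):
    """Left-to-right matrix product of a (n, k, k) stack, n >= 1.

    Hypothetical helper; iterating over the stack yields its 2-D matrices.
    """
    return reduce(np.matmul, stack)

a = np.random.randn(1000, 4, 4)
res = chain_matmul_reduce(a)  # shape (4, 4), no recursion involved
```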
Upvotes: 2