Reputation: 75
I have an 8x8x25000 array W and an 8x25000 array r. I want to multiply each 8x8 slice of W by the corresponding column (8x1) of r and save the result in Wres, which will end up being an 8x25000 matrix.
I am accomplishing this using a for loop as such:
for i in range(0,25000):
    Wres[:,i] = np.matmul(W[:,:,i],r[:,i])
But this is slow and I am hoping there is a quicker way to accomplish this.
Any ideas?
Upvotes: 3
Views: 249
Reputation: 40888
An alternative to using np.matmul is np.einsum, which does the job in one shorter and arguably more palatable line of code with no method chaining.
Example arrays:
np.random.seed(123)
w = np.random.rand(8,8,25000)
r = np.random.rand(8,25000)
wres = np.einsum('ijk,jk->ik',w,r)
# a quick check on result equivalency to your loop
print(np.allclose(np.matmul(w[:, :, 1], r[:, 1]), wres[:, 1]))
True
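To unpack the subscript string: in 'ijk,jk->ik' the index j appears in both inputs but not in the output, so it is summed over, while k appears everywhere and simply labels the slice. Here is a minimal sketch that spells this out as explicit loops on small arrays. The sizes and the names w_small, r_small and expected are illustrative only, not from the question.

```python
import numpy as np

rng = np.random.default_rng(0)
w_small = rng.random((3, 3, 5))   # five 3x3 slices
r_small = rng.random((3, 5))      # five length-3 columns

# Explicit form of 'ijk,jk->ik': for every (i, k), sum over the shared index j
expected = np.zeros((3, 5))
for i in range(3):
    for k in range(5):
        for j in range(3):
            expected[i, k] += w_small[i, j, k] * r_small[j, k]

result = np.einsum('ijk,jk->ik', w_small, r_small)
print(np.allclose(result, expected))
```

The triple loop is only there to show the meaning of the subscripts; it is far slower than the einsum call.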
Timing is equivalent to @Imanol's solution, so take your pick of the two. Both are roughly 25x faster than looping in the timings below. Here, einsum will be competitive because of the size of the arrays. With arrays larger than these it would likely win out, and with smaller arrays it would likely lose. See this discussion for more.
def solution1():
    return np.einsum('ijk,jk->ik',w,r)

def solution2():
    return np.squeeze(np.matmul(w.transpose(2, 0, 1), r.T[..., None])).T

def solution3():
    Wres = np.empty((8, 25000))
    for i in range(0,25000):
        Wres[:,i] = np.matmul(w[:,:,i],r[:,i])
    return Wres
%timeit solution1()
100 loops, best of 3: 2.51 ms per loop
%timeit solution2()
100 loops, best of 3: 2.52 ms per loop
%timeit solution3()
10 loops, best of 3: 64.2 ms per loop
Credit to: @Divakar
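%timeit is an IPython magic, so it will not run in a plain Python script. A rough way to reproduce the comparison with the standard timeit module might look like the sketch below. The helper name time_ms and the repeat count are my own choices, and the absolute numbers will differ by machine.

```python
import timeit
import numpy as np

np.random.seed(123)
w = np.random.rand(8, 8, 25000)
r = np.random.rand(8, 25000)

def solution1():
    return np.einsum('ijk,jk->ik', w, r)

def solution2():
    return np.squeeze(np.matmul(w.transpose(2, 0, 1), r.T[..., None])).T

def solution3():
    Wres = np.empty((8, 25000))
    for i in range(0, 25000):
        Wres[:, i] = np.matmul(w[:, :, i], r[:, i])
    return Wres

def time_ms(func, number=5):
    # Hypothetical helper: average wall time per call, in milliseconds
    return timeit.timeit(func, number=number) / number * 1000

for func in (solution1, solution2, solution3):
    print(f"{func.__name__}: {time_ms(func):.2f} ms per call")
```

All three functions should return the same 8x25000 array, which is worth checking with np.allclose before comparing speeds.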
Upvotes: 2
Reputation: 15889
np.matmul can broadcast over a stack of matrices as long as the two arrays share the same length along the leading (stack) axis. From the docs:
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.
Thus, you have to perform 2 operations prior to matmul:
import numpy as np
a = np.random.rand(8,8,100)
b = np.random.rand(8, 100)
1. Reorder a and b so that the first axis holds the 100 slices
2. Add an extra dimension to b so that b.shape = (100, 8, 1)
Then:
at = a.transpose(2, 0, 1) # swap to shape 100, 8, 8
bt = b.T[..., None] # swap to shape 100, 8, 1
c = np.matmul(at, bt)
c is now 100, 8, 1; reshape it back to 8, 100:
c = np.squeeze(c).swapaxes(0, 1)
or
c = np.squeeze(c).T
And last, a one-liner just for convenience:
c = np.squeeze(np.matmul(a.transpose(2, 0, 1), b.T[..., None])).T
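As a sanity check, the one-liner can be compared against the plain loop from the question. The sketch below also shows a variant that indexes the trailing axis with [..., 0] instead of calling np.squeeze. That variant is my own suggestion rather than part of the approach above; it may be safer because np.squeeze drops every length-1 axis, which would also collapse the stack axis if there were only one slice. The names c_loop and c_idx are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((8, 8, 100))
b = rng.random((8, 100))

# Batched version: move the slice axis first, give b a trailing length-1 axis
c = np.squeeze(np.matmul(a.transpose(2, 0, 1), b.T[..., None])).T

# Reference result from the straightforward loop
c_loop = np.empty((8, 100))
for i in range(a.shape[2]):
    c_loop[:, i] = a[:, :, i] @ b[:, i]

# Variant: drop only the trailing axis explicitly instead of squeezing
c_idx = np.matmul(a.transpose(2, 0, 1), b.T[..., None])[..., 0].T

print(c.shape, np.allclose(c, c_loop), np.allclose(c_idx, c_loop))
```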
Upvotes: 3