user3799584

Reputation: 927

How to generalize elementwise matrix operations in numpy

Suppose I have two sets of n vectors, represented by 3xn arrays V and W, and a set of n matrices, represented by a 3x3xn array Q. How do I vectorize operations to give me

a) the set of n vectors Q[:,:,k]*V[:,k] for k in range(n)

b) the set of n scalars W[:,k].T*Q[:,:,k]*V[:,k] for k in range(n)

Here the multiplication is to be interpreted as matrix multiplication. I can't make sense of the einsum documentation, which I think is what I should be using. Any clarification or solution would be appreciated.
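For reference, here is a minimal explicit-loop sketch of what (a) and (b) compute, assuming NumPy and the shapes described above (V and W of shape (3, n), Q of shape (3, 3, n)). The function names `loop_a` and `loop_b` are hypothetical; any vectorized version should reproduce their results.

```python
import numpy as np

def loop_a(Q, V):
    """(a): stack the n matrix-vector products Q[:, :, k] @ V[:, k] as columns."""
    n = V.shape[1]
    out = np.empty((Q.shape[0], n), dtype=np.result_type(Q, V))
    for k in range(n):
        out[:, k] = Q[:, :, k] @ V[:, k]
    return out

def loop_b(W, Q, V):
    """(b): the n scalars W[:, k] @ Q[:, :, k] @ V[:, k]."""
    n = V.shape[1]
    return np.array([W[:, k] @ Q[:, :, k] @ V[:, k] for k in range(n)])

# Small integer test data (n = 4), so results are easy to compare exactly.
n = 4
V = W = np.arange(3 * n).reshape(3, n)
Q = np.arange(3 * 3 * n).reshape(3, 3, n)
print(loop_a(Q, V))
print(loop_b(W, Q, V))
```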

Upvotes: 2

Views: 130

Answers (1)

hpaulj

Reputation: 231605

a) the set of n vectors Q[:,:,k]*V[:,k] for k in range(n)

np.einsum('ijk,jk->ik', Q, V)

should produce a (3,n) array. The matrix-product summation is over j, while k is shared by both operands and kept in the output.
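As a cross-check, the same batched matrix-vector product can be sketched with `np.matmul` (the `@` operator), which treats leading axes as batch axes. This is a swapped-in alternative, not part of the original answer, and it assumes the batch axis k is first moved to the front.

```python
import numpy as np

n = 4
V = np.arange(3 * n).reshape(3, n)
Q = np.arange(3 * 3 * n).reshape(3, 3, n)

# einsum: a[i, k] = sum_j Q[i, j, k] * V[j, k]
a_einsum = np.einsum('ijk,jk->ik', Q, V)

# matmul broadcasts over leading axes, so move k to the front:
# Qb has shape (n, 3, 3) and Vb has shape (n, 3, 1).
Qb = np.moveaxis(Q, -1, 0)
Vb = V.T[:, :, np.newaxis]
a_matmul = (Qb @ Vb)[:, :, 0].T  # (n, 3, 1) -> (n, 3) -> (3, n)

assert np.array_equal(a_einsum, a_matmul)
```

The einsum form avoids the axis shuffling because the subscripts name the batch axis directly.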

b) the set of n scalars W[:,k].T*Q[:,:,k]*V[:,k] for k in range(n)

np.einsum('ik,ijk,jk->k', W, Q, V)

I'm working from memory, and this is my best guess as to what you need, so my 'ij' subscripts might need adjusting. But give these a try and let me know how it works.
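One way to see why 'ik,ijk,jk->k' gives (b): it is the (a) result contracted with W over i, i.e. a column-wise dot product. A small sketch, assuming the same n = 4 test arrays as in the testing session; the variable names are illustrative.

```python
import numpy as np

n = 4
V = W = np.arange(3 * n).reshape(3, n)
Q = np.arange(3 * 3 * n).reshape(3, 3, n)

# (a): QV[:, k] = Q[:, :, k] @ V[:, k]
QV = np.einsum('ijk,jk->ik', Q, V)

# (b) in one call: sum over both i and j, keep only k
b_direct = np.einsum('ik,ijk,jk->k', W, Q, V)

# (b) as a column-wise dot product of W with the (a) result
b_from_a = np.einsum('ik,ik->k', W, QV)
# equivalently, without einsum: multiply elementwise, then sum down each column
b_sum = (W * QV).sum(axis=0)

assert np.array_equal(b_direct, b_from_a)
assert np.array_equal(b_direct, b_sum)
```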


Testing, with n = 4:

In [180]: V=W=np.arange(3*n).reshape(3,n)

In [181]: Q=np.arange(3*3*n).reshape(3,3,n)

In [182]: np.einsum('ijk,jk->ik',Q,V)
Out[182]: 
array([[ 80, 107, 140, 179],
       [224, 287, 356, 431],
       [368, 467, 572, 683]])

In [183]: np.einsum('ik,ijk,jk',W,Q,V)
Out[183]: 28788    # summation over k

In [184]: np.einsum('ik,ijk,jk->k',W,Q,V)
Out[184]: array([ 3840,  5745,  8136, 11067])

Sometimes breaking an einsum into several steps is faster, since it keeps the iteration space from getting too big. I don't think that's the case here, but here's what it would look like:

In [185]: np.einsum('jk,jk->k',np.einsum('ik,ijk->jk',W,Q),V)
Out[185]: array([ 3840,  5745,  8136, 11067])
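NumPy can also choose such an intermediate contraction order automatically through the `optimize` argument of `np.einsum`, and `np.einsum_path` reports the order it picks. A sketch assuming NumPy 1.12 or later (where `optimize` was added); for arrays this small, any speed difference is likely negligible.

```python
import numpy as np

n = 4
V = W = np.arange(3 * n).reshape(3, n)
Q = np.arange(3 * 3 * n).reshape(3, 3, n)

# Let einsum pick a pairwise contraction order instead of one big loop nest.
b_opt = np.einsum('ik,ijk,jk->k', W, Q, V, optimize=True)

# einsum_path returns the chosen contraction path plus a human-readable report.
path, report = np.einsum_path('ik,ijk,jk->k', W, Q, V, optimize='greedy')
print(path)
print(report)

# The manual two-step version, for comparison.
b_two_step = np.einsum('jk,jk->k', np.einsum('ik,ijk->jk', W, Q), V)
assert np.array_equal(b_opt, b_two_step)
```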

And using the ellipsis notation from Jaime's comment:

In [186]: np.einsum('i...,ij...,j...',W,Q,V)
Out[186]: array([ 3840,  5745,  8136, 11067])

In [187]: np.einsum('ij...,j...->i...',Q,V)
Out[187]: 
array([[ 80, 107, 140, 179],
       [224, 287, 356, 431],
       [368, 467, 572, 683]])
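The ellipsis forms do not require the batch to be a single trailing axis of length n: `...` stands for any number of trailing axes, as long as they broadcast. A sketch assuming a hypothetical 2 x 5 grid of matrices and vector pairs (random data, for illustration only):

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical batch shape (2, 5): 10 matrices and vector pairs arranged on a grid.
Q = rng.standard_normal((3, 3, 2, 5))
V = rng.standard_normal((3, 2, 5))
W = rng.standard_normal((3, 2, 5))

a = np.einsum('ij...,j...->i...', Q, V)    # shape (3, 2, 5)
b = np.einsum('i...,ij...,j...', W, Q, V)  # shape (2, 5); implicit mode keeps '...'

# Spot-check one batch element against plain matrix products.
assert np.allclose(a[:, 1, 3], Q[:, :, 1, 3] @ V[:, 1, 3])
assert np.allclose(b[1, 3], W[:, 1, 3] @ Q[:, :, 1, 3] @ V[:, 1, 3])
```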

Upvotes: 3
