I'm trying to get the dot product of each pair of corresponding elements in an nx2x3 array and an nx3 array (n is always the same for both arrays).
For example:
import numpy as np
a = np.arange(12).reshape(4,3)
b = np.arange(24).reshape(4,2,3)
The array I'm trying to get would contain these results, one row per index:
print(np.dot(b[0],a[0]))
print(np.dot(b[1],a[1]))
print(np.dot(b[2],a[2]))
print(np.dot(b[3],a[3]))
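As a reference for checking any vectorized version, the four calls above can be collected into a single (n, 2) array with a plain loop. A minimal sketch, assuming the same a and b as in the question:

```python
import numpy as np

a = np.arange(12).reshape(4, 3)
b = np.arange(24).reshape(4, 2, 3)

# One (2, 3) . (3,) dot product per index i, stacked into an array of shape (4, 2)
expected = np.array([np.dot(b[i], a[i]) for i in range(len(a))])
print(expected)
```

This loops in Python, so it is slow for large n, but it is easy to trust as a correctness check.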
I'm sure there's a way to use einsum or tensordot for this, but I can't get it to work.
You could use einsum this way:
>>> np.einsum('ij,ikj->ik', a, b)
array([[  5,  14],
       [ 86, 122],
       [275, 338],
       [572, 662]])
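All that's happening here is that axis 0 of a is multiplied with axis 0 of b (the shared index i, which is kept in the output), and axis 1 of a is multiplied with axis 2 of b (index j). Because j does not appear after the ->, values along that axis are summed, and a 2D array of shape (n, 2) is returned.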
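The same computation can be spelled out with broadcasting, which makes the multiply-then-sum description concrete. A sketch, assuming the question's a and b: a[:, None, :] has shape (n, 1, 3) and broadcasts against b's (n, 2, 3), and summing over the last axis plays the role of the summed index j.

```python
import numpy as np

a = np.arange(12).reshape(4, 3)
b = np.arange(24).reshape(4, 2, 3)

# Insert a length-1 axis so a lines up with b: (4, 1, 3) against (4, 2, 3).
# The elementwise product pairs index i with i and index j with j.
products = a[:, None, :] * b          # shape (4, 2, 3)

# Summing over the last axis (j) leaves i and k, as in 'ij,ikj->ik'.
broadcast_result = products.sum(axis=2)

einsum_result = np.einsum('ij,ikj->ik', a, b)
print(np.array_equal(broadcast_result, einsum_result))  # True
```

Note that this version materializes the full (n, 2, 3) product array before summing, which einsum's single call lets you avoid writing out by hand.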
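(tensordot doesn't apply itself neatly to this particular problem, as we need multiplication along two axes but summation along just one. With tensordot these operations only come in pairs: every axis pair it multiplies is also summed over, and all remaining axes form an outer product, so there is no way to keep axis 0 as a shared axis.)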
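To see this concretely: contracting j with tensordot also pairs every row of b with every row of a, producing an (n, 2, n) array, of which only the diagonal over the two n axes is wanted. A sketch for illustration only (it does roughly n times the necessary work, so it is not a practical substitute for einsum):

```python
import numpy as np

a = np.arange(12).reshape(4, 3)
b = np.arange(24).reshape(4, 2, 3)

# Contract axis 2 of b with axis 1 of a (the shared length-3 axis j).
# The other axes form an outer product: full[i, k, m] = sum_j b[i, k, j] * a[m, j]
full = np.tensordot(b, a, axes=([2], [1]))
print(full.shape)  # (4, 2, 4)

# Only entries with m == i are wanted. np.diagonal removes axes 0 and 2 and
# appends the diagonal as the last axis, giving shape (2, 4); transpose to (4, 2).
result = np.diagonal(full, axis1=0, axis2=2).T

print(np.array_equal(result, np.einsum('ij,ikj->ik', a, b)))  # True
```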