Reputation: 3806
But obviously I'm doing something wrong.
I've been chasing a bug all night, and I've finally solved it. Consider:
import numpy as np
xs = np.arange(100 * 3).reshape(100, 3)
W = np.arange(3 * 17).reshape(3, 17)
a = np.einsum('df, hg -> dg', xs, W)
b = np.dot(xs, W)
In the above, a != b.
The issue turned out to be in the einsum subscripts: I wrote df, hg -> dg, but if I replace that h with an f, it works as expected:
a = np.einsum('df, fg -> dg', xs, W)
b = np.dot(xs, W)
Now a == b.
What is the summation doing differently in the two cases? I'd expect them to be the same.
Upvotes: 4
Views: 277
Reputation: 53029
Here are equivalent broadcasting-based expressions; perhaps they help you understand the difference:
dffg = (xs[:,:,None]*W[None,:,:]).sum(1)
dfhg = (xs[:,:,None,None]*W[None,None,:,:]).sum((1,2))
(a==dfhg).all()
# True
(b==dffg).all()
# True
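To make the two subscript strings even more concrete, here is a minimal sketch that spells each one out as explicit Python loops. The function names einsum_df_hg and einsum_df_fg are hypothetical, chosen for illustration only (they are not NumPy APIs), and the loops are slow; the point is only to show which indices each expression sums over.

```python
import numpy as np

xs = np.arange(100 * 3).reshape(100, 3)
W = np.arange(3 * 17).reshape(3, 17)


def einsum_df_hg(x, w):
    """Loop version of np.einsum('df,hg->dg', x, w): f and h are summed independently."""
    D, F = x.shape
    H, G = w.shape
    out = np.zeros((D, G), dtype=np.result_type(x, w))
    for d in range(D):
        for g in range(G):
            for f in range(F):
                for h in range(H):
                    out[d, g] += x[d, f] * w[h, g]
    return out


def einsum_df_fg(x, w):
    """Loop version of np.einsum('df,fg->dg', x, w): one shared index f (a matrix product)."""
    D, F = x.shape
    F2, G = w.shape
    assert F == F2, "inner dimensions must match"
    out = np.zeros((D, G), dtype=np.result_type(x, w))
    for d in range(D):
        for g in range(G):
            for f in range(F):
                out[d, g] += x[d, f] * w[f, g]
    return out


print(np.array_equal(einsum_df_hg(xs, W), np.einsum('df,hg->dg', xs, W)))  # True
print(np.array_equal(einsum_df_fg(xs, W), np.dot(xs, W)))                  # True
```

The only difference between the two functions is whether the second factor is indexed by its own loop variable h or by the shared f.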
In the dfhg case the two summed axes (f and h) belong to different operands, so they never pair up; the double sum therefore factorizes, and each operand can be summed separately before taking the outer product:
dfhg_ = (xs.sum(1)[:,None]*W.sum(0)[None,:])
(a==dfhg_).all()
# True
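Written out with indices (using x for xs), the factorization holds because W_{hg} does not depend on f and x_{df} does not depend on h, whereas in the matrix product the single shared index f ties the two factors together term by term:

```latex
\[
\sum_{f}\sum_{h} x_{df}\,W_{hg}
  = \Big(\sum_{f} x_{df}\Big)\Big(\sum_{h} W_{hg}\Big)
\qquad\text{versus}\qquad
\sum_{f} x_{df}\,W_{fg} = (xW)_{dg}
\]
```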
Contrast this with the dffg case, where a dot product is formed between each row of xs and each column of W.
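A quick sketch checking that row-by-column description, reusing the question's xs and W (the indices d, g = 5, 11 are arbitrary choices for illustration):

```python
import numpy as np

xs = np.arange(100 * 3).reshape(100, 3)
W = np.arange(3 * 17).reshape(3, 17)

b = np.einsum('df,fg->dg', xs, W)

# A single entry b[d, g] is the dot product of row d of xs with column g of W.
d, g = 5, 11
print(b[d, g] == np.dot(xs[d, :], W[:, g]))  # True

# The same check for every (d, g) pair at once.
rows_dot_cols = np.array([[np.dot(xs[i, :], W[:, j]) for j in range(W.shape[1])]
                          for i in range(xs.shape[0])])
print(np.array_equal(b, rows_dot_cols))  # True
```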
Upvotes: 2
Reputation: 25518
The correct way to do the matrix multiplication using np.einsum is to repeat the "middle" index (indicating summation over row times column), as you found:
import numpy as np
a = np.array([[1,2],[3,4]])
b = np.array([[1,-2],[-0.4,3]])
np.einsum('df,fg->dg', a, b)
array([[ 0.2, 4. ],
       [ 1.4, 6. ]])
a.dot(b)
array([[ 0.2, 4. ],
       [ 1.4, 6. ]])
If you don't repeat an index and instead keep every index in the output, you get each value of a multiplied by the whole of b:
np.einsum('df, hg -> dfhg', a, b)
array([[[[ 1. , -2. ],
         [ -0.4, 3. ]],
        [[ 2. , -4. ],
         [ -0.8, 6. ]]],
       [[[ 3. , -6. ],
         [ -1.2, 9. ]],
        [[ 4. , -8. ],
         [ -1.6, 12. ]]]])
is the same as
a[:,:, None, None] * b
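The same four-dimensional outer product can also be built with the ufunc method np.multiply.outer. A minimal sketch reusing this answer's 2x2 arrays, checking that the three forms agree:

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[1, -2], [-0.4, 3]])

via_einsum = np.einsum('df,hg->dfhg', a, b)    # no index repeated or dropped: nothing is summed
via_broadcast = a[:, :, None, None] * b        # b broadcasts over the two trailing axes
via_outer = np.multiply.outer(a, b)            # result[d, f, h, g] = a[d, f] * b[h, g]

print(via_einsum.shape)                        # (2, 2, 2, 2)
print(np.allclose(via_einsum, via_broadcast))  # True
print(np.allclose(via_einsum, via_outer))      # True
```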
When you leave the middle indices out of the output (the part after the explicit ->), those axes are summed over:
np.einsum('df, hg -> dg', a, b)
array([[ 1.8, 3. ],
       [ 4.2, 7. ]])
is the same as:
np.einsum('df, hg -> dfhg', a, b).sum(axis=1).sum(axis=1)
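For comparison, here is a minimal sketch of implicit mode, i.e. leaving out -> altogether. In that mode NumPy sums every index that appears more than once and orders the remaining output axes alphabetically, so 'df,fg' acts as a matrix product, while 'df,hg' sums nothing and returns a 4-D array with axes d, f, g, h. The last check confirms that the explicit 'df,hg->dg' result equals the row sums of a times the column sums of b.

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[1, -2], [-0.4, 3]])

# Implicit mode: f is repeated, so it is summed; the output axes are d, g.
print(np.allclose(np.einsum('df,fg', a, b), a @ b))  # True

# Implicit mode with no repeated index: nothing is summed, axes come out as d, f, g, h.
implicit = np.einsum('df,hg', a, b)
print(implicit.shape)                                          # (2, 2, 2, 2)
print(np.allclose(implicit, np.einsum('df,hg->dfgh', a, b)))   # True

# Explicit 'df,hg->dg': f and h are summed, which factorizes into an outer product of sums.
explicit = np.einsum('df,hg->dg', a, b)
print(np.allclose(explicit, np.outer(a.sum(axis=1), b.sum(axis=0))))  # True
```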
Here is a good guide to einsum (not mine).
Upvotes: 1