Reputation: 11
Can anyone please help me understand how this matrix multiplication code actually works?
def matrix_mul(a, b):
    return [[sum(i * j for i, j in zip(r, c)) for c in zip(*b)]
            for r in a]
a = [[1, 2], [3, 4]]
b = [[5, 1], [2, 1]]
c = matrix_mul(a, b)
Upvotes: 1
Views: 3333
Reputation: 11
So by declaring:
def matrix_mul(a, b):
    return [[sum(i * j for i, j in zip(r, c)) for c in zip(*b)]
            for r in a]
you are actually doing the following:
a = [[1, 2], [3, 4]]
b = [[5, 1], [2, 1]]
The first element of b (5) is multiplied by the first element of a (1), then the second element of a (2) is multiplied by the first element of b[1] (2), and the two products are added: 1 * 5 + 2 * 2 = 9. So to obtain c you need to do:
c = [[9,3],[23,7]]
9 = a[0][0] * b[0][0] + a[0][1] * b[1][0]
3 = a[0][0] * b[0][1] + a[0][1] * b[1][1]
23 = a[1][0] * b[0][0] + a[1][1] * b[1][0]
7 = a[1][0] * b[0][1] + a[1][1] * b[1][1]
I think this is the correct answer. Hope I helped you.
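The four formulas above all follow one pattern: c[i][k] = sum over j of a[i][j] * b[j][k], i.e. row i of a combined with column k of b. A minimal sketch of that pattern with explicit indices, assuming a and b are lists of lists whose inner dimensions match (the function name matrix_mul_indexed is hypothetical, not from the question):

```python
def matrix_mul_indexed(a, b):
    """Multiply a (n x m) by b (m x p) using explicit indices."""
    n, m, p = len(a), len(b), len(b[0])
    # Every row of a must have as many entries as b has rows
    if any(len(row) != m for row in a):
        raise ValueError("inner dimensions do not match")
    c = [[0] * p for _ in range(n)]
    for i in range(n):
        for k in range(p):
            # c[i][k] = a[i][0] * b[0][k] + a[i][1] * b[1][k] + ...
            c[i][k] = sum(a[i][j] * b[j][k] for j in range(m))
    return c


a = [[1, 2], [3, 4]]
b = [[5, 1], [2, 1]]
print(matrix_mul_indexed(a, b))  # [[9, 3], [23, 7]]
```

The loop over k walks across the columns of b by index, which is exactly what zip(*b) does for you in the one-liner.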
Upvotes: 0
Reputation: 13079
Your expression has essentially three nested list comprehensions (although in fact one of them is a generator expression). If you replace them with explicit loops, add some appropriately named variables for the lists which are being built up, and add some print statements, then you can see what is going on:
def matrix_mul(a, b):
    matrix_out = []
    for r in a:
        print("Input row of a", r)
        row_out = []
        for c in zip(*b):
            print("  Input column of b", c)
            products = []
            for i, j in zip(r, c):
                print("    Values to multiply:", i, j)
                products.append(i * j)
            print("  List of products:", products)
            print("  Sum of products:", sum(products))
            row_out.append(sum(products))
        matrix_out.append(row_out)
        print("Row of answer", row_out)
        print()
    print("Answer", matrix_out)
    return matrix_out
a = [[1, 2], [3, 4]]
b = [[5, 1], [2, 1]]
c = matrix_mul(a, b)
Gives:
Input row of a [1, 2]
  Input column of b (5, 2)
    Values to multiply: 1 5
    Values to multiply: 2 2
  List of products: [5, 4]
  Sum of products: 9
  Input column of b (1, 1)
    Values to multiply: 1 1
    Values to multiply: 2 1
  List of products: [1, 2]
  Sum of products: 3
Row of answer [9, 3]

Input row of a [3, 4]
  Input column of b (5, 2)
    Values to multiply: 3 5
    Values to multiply: 4 2
  List of products: [15, 8]
  Sum of products: 23
  Input column of b (1, 1)
    Values to multiply: 3 1
    Values to multiply: 4 1
  List of products: [3, 4]
  Sum of products: 7
Row of answer [23, 7]

Answer [[9, 3], [23, 7]]
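The trace shows that each output element pairs one row of a with one column of b, so the result has as many rows as a and as many columns as b. A quick check of the original one-liner on non-square inputs (the matrices here are made-up examples) also shows one caveat: zip stops at the shorter input, so mismatched inner dimensions are silently truncated rather than reported as an error.

```python
def matrix_mul(a, b):
    return [[sum(i * j for i, j in zip(r, c)) for c in zip(*b)]
            for r in a]


# 2 x 3 times 3 x 2 gives a 2 x 2 result (rows of a by columns of b)
a = [[1, 2, 3], [4, 5, 6]]
b = [[7, 8], [9, 10], [11, 12]]
print(matrix_mul(a, b))  # [[58, 64], [139, 154]]

# Mismatched inner dimensions (1 x 3 times 2 x 1) raise no error:
# zip() stops at the shorter input, so the 3 in a is silently ignored
print(matrix_mul([[1, 2, 3]], [[1], [2]]))  # [[5]]
```

If you need the mismatch reported, add an explicit length check before multiplying (or, on Python 3.10+, pass strict=True to the inner zip, which raises ValueError when the lengths differ).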
To explain the zip(*b), using the example values:
b = [[5, 1], [2, 1]]
the * expands the top-level list into separate arguments, so the call becomes
zip([5, 1], [2, 1])
and zip then yields a sequence of tuples: first the first element from each input list, then the second element from each input list (and so on, although in this case there are only two pairs), i.e.
(5, 2)
(1, 1)
In other words, the columns of b.
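A short check of the unpacking described above: zip(*b) is the same call as zip(b[0], b[1]), and collecting its output gives the columns of b. The non-square matrix m is a made-up example showing that the same trick transposes any rectangular list of lists.

```python
b = [[5, 1], [2, 1]]

# zip(*b) unpacks b, so it is the same call as zip(b[0], b[1])
assert list(zip(*b)) == list(zip([5, 1], [2, 1]))

# Each tuple it yields is one column of b
columns = list(zip(*b))
print(columns)  # [(5, 2), (1, 1)]

# The same trick transposes a non-square matrix (2 x 3 -> 3 x 2)
m = [[1, 2, 3], [4, 5, 6]]
transposed = list(zip(*m))
print(transposed)  # [(1, 4), (2, 5), (3, 6)]
```

Note that zip yields tuples, not lists; wrap each one in list() if you need mutable rows.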
(Note: because the input to sum in the original code is actually a generator expression rather than a list comprehension, it does not in fact produce a list of products; instead it generates the products one at a time, and sum consumes each one as it is produced. Apart from not building an intermediate list in memory, the effect is the same.)
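To see the difference described in the note, here is a small comparison using the first row of a and the first column of b from the example: both forms give the same sum, but the generator hands out products only on demand and can be consumed only once.

```python
r, c = [1, 2], (5, 2)  # first row of a, first column of b

# List comprehension: the whole list [5, 4] is built first, then summed
from_list = sum([i * j for i, j in zip(r, c)])

# Generator expression: each product is produced only when sum() asks for it
from_gen = sum(i * j for i, j in zip(r, c))

print(from_list, from_gen)  # 9 9

# A generator can be stepped through manually, and it is single-use
gen = (i * j for i, j in zip(r, c))
print(next(gen))  # 5
print(next(gen))  # 4
leftover = sum(gen)  # already exhausted, so nothing is left to add
print(leftover)  # 0
```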
Upvotes: 3