Reputation: 1134
I'm trying to implement divide-and-conquer matrix multiplication (the version with 8 recursive calls, not Strassen). I thought I had it figured out, but it produces weird output: too many nested lists and the wrong values. I suspect the problem is in how I'm summing the results of the 8 recursive calls, but I'm not sure.
def multiMatrix(x,y):
    n = len(x)
    if n == 1:
        return x[0][0] * y[0][0]
    else:
        a = [[col for col in row[:len(row)/2]] for row in x[:len(x)/2]]
        b = [[col for col in row[len(row)/2:]] for row in x[:len(x)/2]]
        c = [[col for col in row[:len(row)/2]] for row in x[len(x)/2:]]
        d = [[col for col in row[len(row)/2:]] for row in x[len(x)/2:]]
        e = [[col for col in row[:len(row)/2]] for row in y[:len(y)/2]]
        f = [[col for col in row[len(row)/2:]] for row in y[:len(y)/2]]
        g = [[col for col in row[:len(row)/2]] for row in y[len(y)/2:]]
        h = [[col for col in row[len(row)/2:]] for row in y[len(y)/2:]]
        ae = multiMatrix(a,e)
        bg = multiMatrix(b,g)
        af = multiMatrix(a,f)
        bh = multiMatrix(b,h)
        ce = multiMatrix(c,e)
        dg = multiMatrix(d,g)
        cf = multiMatrix(c,f)
        dh = multiMatrix(d,h)
        c = [[ae+bg,af+bh],[ce+dg,cf+dh]]
        return c
a = [
    [1,2,3,4],
    [5,6,7,8],
    [9,10,11,12],
    [13,14,15,16]
]
b = [
    [1,2,3,4],
    [5,6,7,8],
    [9,10,11,12],
    [13,14,15,16]
]
print multiMatrix(a,b)
Upvotes: 1
Views: 4688
Reputation: 11
def join_horiz(a, b):
    # Concatenate two blocks side by side, row by row.
    return [rowa + rowb for rowa, rowb in zip(a,b)]

def MatAdd(A,B):
    # Element-wise sum of two square matrices of the same size.
    resultant = [[0 for i in range(len(A))] for j in range(len(A))]
    for i in range(len(A)):
        for j in range(len(A)):
            resultant[i][j] = A[i][j] + B[i][j]
    return resultant

def createSubmatrices(A,starting_index,rows,columns):
    # Copy a rows x columns block of A starting at (row, column) = starting_index.
    resultant = [[0 for j in range(columns)] for i in range(rows)]
    for i in range(rows):
        for j in range(columns):
            resultant[i][j] = A[starting_index[0] + i][starting_index[1] + j]
    return resultant

def MatMulRecursive(A,B,n):
    # n is the dimension of the square matrices A and B (a power of two).
    if n == 1:
        return [[A[0][0]*B[0][0]]]
    else:
        A11 = createSubmatrices(A, (0,0), n//2, n//2)
        A12 = createSubmatrices(A, (0,n//2), n//2, n//2)
        A21 = createSubmatrices(A, (n//2,0), n//2, n//2)
        A22 = createSubmatrices(A, (n//2,n//2), n//2, n//2)
        B11 = createSubmatrices(B, (0,0), n//2, n//2)
        B12 = createSubmatrices(B, (0,n//2), n//2, n//2)
        B21 = createSubmatrices(B, (n//2,0), n//2, n//2)
        B22 = createSubmatrices(B, (n//2,n//2), n//2, n//2)
        C11 = list(MatAdd(MatMulRecursive(A11, B11, n//2) , MatMulRecursive(A12, B21, n//2)))
        C12 = list(MatAdd(MatMulRecursive(A11, B12, n//2) , MatMulRecursive(A12, B22, n//2)))
        C21 = list(MatAdd(MatMulRecursive(A21, B11, n//2) , MatMulRecursive(A22, B21, n//2)))
        C22 = list(MatAdd(MatMulRecursive(A21, B12, n//2) , MatMulRecursive(A22, B22, n//2)))
        # Top half of the result stacked on the bottom half.
        return join_horiz(C11, C12) + join_horiz(C21, C22)

A = [[1,1,1,1], [1,5,5,1], [1,7,7,1], [3,3,3,2]]
B = [[2,2,2,2], [2,2,2,2], [2,2,2,2], [2,2,2,2]]
C = MatMulRecursive(A, B, 4)
print(C)
Upvotes: 1
Reputation: 1
If the recursive function takes only the two matrices as arguments, the code will be cleaner.
Upvotes: 0
Reputation: 53089
Your suspicion is correct: your matrices are still lists, so adding them with + just concatenates them into a longer list.
Try using something like this
def matrix_add(a, b):
    return [[ea+eb for ea, eb in zip(*rowpair)] for rowpair in zip(a, b)]
in your code.
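To see the difference concretely, here is a small sketch comparing `+` on two nested lists with the element-wise `matrix_add`. The sample values `p` and `q` are made up for illustration.

```python
def matrix_add(a, b):
    # Pair up rows, then pair up the elements within each row, and add them.
    return [[ea + eb for ea, eb in zip(*rowpair)] for rowpair in zip(a, b)]

# Illustrative 2 x 2 matrices (not from the question).
p = [[1, 2], [3, 4]]
q = [[10, 20], [30, 40]]

concatenated = p + q        # list concatenation: four rows, nothing is summed
summed = matrix_add(p, q)   # element-wise sum: still 2 x 2

print(concatenated)  # [[1, 2], [3, 4], [10, 20], [30, 40]]
print(summed)        # [[11, 22], [33, 44]]
```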
To join blocks:
def join_horiz(a, b):
    return [rowa + rowb for rowa, rowb in zip(a,b)]
def join_vert(a, b):
    return a+b
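As a quick illustration of how the two join helpers fit together, the sketch below reassembles a 4 x 4 matrix from its four 2 x 2 quadrants. The quadrant names (`top_left` and so on) are only for this example.

```python
def join_horiz(a, b):
    # Glue block b to the right of block a, row by row.
    return [rowa + rowb for rowa, rowb in zip(a, b)]

def join_vert(a, b):
    # Stack block b underneath block a.
    return a + b

# Four quadrants of the 4 x 4 matrix used in the question.
top_left = [[1, 2], [5, 6]]
top_right = [[3, 4], [7, 8]]
bottom_left = [[9, 10], [13, 14]]
bottom_right = [[11, 12], [15, 16]]

full = join_vert(join_horiz(top_left, top_right),
                 join_horiz(bottom_left, bottom_right))
print(full)
# [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
```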
Finally, to make it all work together, I think you have to change your special case for n == 1 so that it returns a 1 x 1 matrix rather than a bare number:
return [[x[0][0] * y[0][0]]]
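Putting the pieces together, one possible version of the whole function might look like the sketch below. It assumes square inputs whose size is a power of two and uses `//` so it runs on Python 3. The name `multi_matrix` and the local `half` are choices made for this sketch, not part of the original code.

```python
def matrix_add(a, b):
    return [[ea + eb for ea, eb in zip(*rowpair)] for rowpair in zip(a, b)]

def join_horiz(a, b):
    return [rowa + rowb for rowa, rowb in zip(a, b)]

def join_vert(a, b):
    return a + b

def multi_matrix(x, y):
    # Assumes x and y are n x n with n a power of two.
    n = len(x)
    if n == 1:
        # Return a 1 x 1 matrix, not a bare number, so the helpers still work.
        return [[x[0][0] * y[0][0]]]
    half = n // 2
    # Split each input into four quadrants.
    a = [row[:half] for row in x[:half]]
    b = [row[half:] for row in x[:half]]
    c = [row[:half] for row in x[half:]]
    d = [row[half:] for row in x[half:]]
    e = [row[:half] for row in y[:half]]
    f = [row[half:] for row in y[:half]]
    g = [row[:half] for row in y[half:]]
    h = [row[half:] for row in y[half:]]
    # Eight recursive products, summed pairwise into the four result quadrants.
    top = join_horiz(matrix_add(multi_matrix(a, e), multi_matrix(b, g)),
                     matrix_add(multi_matrix(a, f), multi_matrix(b, h)))
    bottom = join_horiz(matrix_add(multi_matrix(c, e), multi_matrix(d, g)),
                        matrix_add(multi_matrix(c, f), multi_matrix(d, h)))
    return join_vert(top, bottom)

m = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
result = multi_matrix(m, m)
print(result)
# [[90, 100, 110, 120], [202, 228, 254, 280], [314, 356, 398, 440], [426, 484, 542, 600]]
```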
Edit:
I just realised that this will only work for power-of-two dimensions. Otherwise you will have to deal with non-square matrices, and at some point x will be 1 x something and your special case won't work. So you'll also have to check len(x[0]) (if n > 0).
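A different way around the power-of-two restriction, instead of adding rectangular base cases as described above, is to pad both inputs with zeros up to the next power of two, multiply, and crop the result. The sketch below shows only that padding step. All names in it (`next_power_of_two`, `pad_square`, `multiply_padded`, `naive_multiply`) are hypothetical, and `naive_multiply` merely stands in for the recursive function so the example is self-contained.

```python
def next_power_of_two(n):
    # Smallest power of two that is >= n (for n >= 1).
    p = 1
    while p < n:
        p *= 2
    return p

def pad_square(m, size):
    # Extend every row with zeros, then append all-zero rows, up to size x size.
    rows = [list(row) + [0] * (size - len(row)) for row in m]
    rows += [[0] * size for _ in range(size - len(m))]
    return rows

def multiply_padded(x, y, multiply):
    # multiply is any function that multiplies two square power-of-two matrices.
    rows, inner, cols = len(x), len(y), len(y[0])
    size = next_power_of_two(max(rows, inner, cols))
    product = multiply(pad_square(x, size), pad_square(y, size))
    # Crop away the rows and columns that came from padding.
    return [row[:cols] for row in product[:rows]]

def naive_multiply(x, y):
    # Stand-in for the recursive routine: plain row-by-column multiplication.
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]

x = [[1, 2, 3], [4, 5, 6]]        # 2 x 3
y = [[7, 8], [9, 10], [11, 12]]   # 3 x 2
result = multiply_padded(x, y, naive_multiply)
print(result)  # [[58, 64], [139, 154]]
```

The zero padding does not change the entries in the cropped region, at the cost of some wasted work on the padded rows and columns.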
Upvotes: 3