Reputation: 107
I am trying to implement Karatsuba multiplication in Python. The inputs are two integers of the same length, and that length is a power of 2.
def mult(x,y):
    if int(x) < 10 and int(y) < 10:
        return int(x)*int(y)
    x_length = len(str(x))//2
    y_length = len(str(y))//2
    a = str(x)[:x_length]
    b = str(x)[x_length:]
    c = str(y)[:y_length]
    d = str(y)[y_length:]
    n = len(a) + len(b)
    m = n//2
    return 10**n* mult(a,c) + 10**m*(mult(a+b, c+d)-mult(a,c)-mult(b,d)) + mult(b,d)
Running mult(1234,5678) gives the following error:
if int(x) < 10 and int(y) <10:
RecursionError: maximum recursion depth exceeded while calling a Python object
However, if I do
def mult(x,y):
    if int(x) < 10 and int(y) < 10:
        return int(x)*int(y)
    x_length = len(str(x))//2
    y_length = len(str(y))//2
    a = str(x)[:x_length]
    b = str(x)[x_length:]
    c = str(y)[:y_length]
    d = str(y)[y_length:]
    n = len(a) + len(b)
    m = n//2
    return 10**n* mult(a,c) + 10**m*(mult(a,d)+mult(b,c)) + mult(b,d)
That is, if I do 4 recursions in the last line (mult(a,c), mult(a,d), mult(b,c), mult(b,d)) rather than 3 as in the first version (mult(a,c), mult(a+b, c+d), mult(b,d)), then it turns out OK.
Why is this happening? And how can I do it with only 3 recursions?
Upvotes: 0
Views: 401
Reputation: 7923
a, b, c, d are strings. String addition is concatenation. "1" + "2" is "12". So what is passed to mult(a+b, c+d) is not what you intended to pass.
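To make the difference concrete, here is a tiny sketch. The values "12" and "34" are just the two halves of 1234 from the question; the variable names as_strings and as_integers are mine.

```python
# Slicing str(x) gives strings, so "+" joins them instead of adding them.
a, b = "12", "34"

as_strings = a + b             # string concatenation
as_integers = int(a) + int(b)  # integer addition, what Karatsuba needs

print(repr(as_strings))  # '1234'
print(as_integers)       # 46
```

Note that the concatenation gives back the original number, so mult(a+b, c+d) is called with the same digits the caller received.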
TL;DR.
First things first: the recursion is supposed to terminate quickly. Let's see why it doesn't. Add print(x, y) at the beginning of mult:
def mult(x, y):
    print(x, y)
    ....
and redirect the output into a file. The result is surprising:
1234 5678
12 56
1 5
12 56
1 5
12 56
1 5
12 56
1 5
....
No wonder the stack overflows. The question is, why do we repeat the 12 56 case? Let's add more instrumentation to find out which recursive call does that:
def mult(x, y, k=-1):
    ....
    print(a, b, c, d)
    ac = mult(a, c, 0)
    bd = mult(b, d, 2)
    return 10**n * ac + 10**m * (mult(a+b, c+d, 1) - ac - bd) + bd
The results are
-1 : 1234 5678
12 34 56 78
0 : 12 56
1 2 5 6
0 : 1 5
2 : 2 6
1 : 12 56
1 2 5 6
0 : 1 5
2 : 2 6
1 : 12 56
1 2 5 6
0 : 1 5
2 : 2 6
1 : 12 56
You can see that the recursive call marked 1 always gets 12 56. It is the call which computes mult(a + b, c + d). Oh well. All of a, b, c, d are strings. "1" + "2" is "12". Not exactly what you meant.
So, make up your mind: are the parameters integers or strings? Treat them accordingly.
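As one possible way to follow that advice, here is a minimal sketch that keeps everything as integers and uses only three recursive calls. It is not the asker's code: the name karatsuba, the use of divmod to split the numbers, and the choice to stop as soon as either operand is a single digit are my assumptions. It assumes non-negative integers.

```python
def karatsuba(x: int, y: int) -> int:
    """Multiply two non-negative integers with three recursive calls."""
    # Base case: one operand is a single digit, multiply directly.
    if x < 10 or y < 10:
        return x * y

    # Split both numbers at the same decimal position m.
    m = max(len(str(x)), len(str(y))) // 2
    a, b = divmod(x, 10**m)  # x == a * 10**m + b
    c, d = divmod(y, 10**m)  # y == c * 10**m + d

    ac = karatsuba(a, c)
    bd = karatsuba(b, d)
    # a + b and c + d are integer sums here, not concatenations.
    # (a + b)(c + d) - ac - bd == ad + bc
    cross = karatsuba(a + b, c + d) - ac - bd

    return ac * 10**(2 * m) + cross * 10**m + bd


print(karatsuba(1234, 5678))  # 7006652
```

Storing ac and bd in variables is what keeps the count at three calls per level; the sums a + b and c + d are strictly smaller than x and y whenever the split is non-trivial, so the recursion shrinks instead of looping.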
Upvotes: 1
Reputation: 46
Note that in your first code snippet you are calling your function not three times but five:
return 10**n* mult(a,c) + 10**m*(mult(a+b, c+d)-mult(a,c)-mult(b,d)) + mult(b,d)
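To see what the extra calls cost, here is a rough sketch that counts calls for both styles. It works on integers rather than the strings from the question, and the reuse flag, the counter list and the count_calls helper are all hypothetical names I made up for the comparison.

```python
def mult(x, y, reuse, counter):
    counter[0] += 1  # count every call, including base cases
    if x < 10 or y < 10:
        return x * y
    m = max(len(str(x)), len(str(y))) // 2
    a, b = divmod(x, 10**m)
    c, d = divmod(y, 10**m)
    ac = mult(a, c, reuse, counter)
    bd = mult(b, d, reuse, counter)
    if reuse:
        # Three calls per level: ac and bd are computed once and reused.
        cross = mult(a + b, c + d, reuse, counter) - ac - bd
    else:
        # Five calls per level: ac and bd are recomputed, as in the snippet.
        cross = (mult(a + b, c + d, reuse, counter)
                 - mult(a, c, reuse, counter)
                 - mult(b, d, reuse, counter))
    return ac * 10**(2 * m) + cross * 10**m + bd


def count_calls(x, y, reuse):
    counter = [0]
    product = mult(x, y, reuse, counter)
    return product, counter[0]


for reuse in (True, False):
    product, calls = count_calls(1234, 5678, reuse)
    print(f"reuse={reuse}: product={product}, calls={calls}")
```

Both variants return the same product; only the number of calls differs.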
I can't really speak for the rest of your code, but from a quick look at the Wikipedia entry on Karatsuba, you can decrease your recursion depth by increasing the base you are using (i.e. from 10 to 100 or 1000). You can also raise the recursion limit with sys.setrecursionlimit, but Python stack frames can get quite big, so try to avoid doing so, as it may be dangerous.
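If you do decide to raise the limit, something like the following sketch keeps the change local. sys.getrecursionlimit and sys.setrecursionlimit are the real standard-library calls; the recursion_limit context manager and the depth function are hypothetical helpers written for this example.

```python
import sys
from contextlib import contextmanager


@contextmanager
def recursion_limit(limit):
    """Temporarily set the recursion limit, restoring the old value on exit."""
    old = sys.getrecursionlimit()
    sys.setrecursionlimit(limit)
    try:
        yield
    finally:
        sys.setrecursionlimit(old)


def depth(n):
    # Toy recursive function that needs about n stack frames.
    return 0 if n == 0 else 1 + depth(n - 1)


before = sys.getrecursionlimit()
with recursion_limit(before + 2000):
    result = depth(1500)

print(result, sys.getrecursionlimit() == before)
```

Keep in mind this only helps when the recursion is deep but finite. A call that keeps receiving the same arguments will hit any limit you set.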
Upvotes: 0