Reputation: 59
I came up with this solution for the following question from "Cracking the Coding Interview". I think it's faster and more elegant than the solution given in the book, but I'm not sure it works in all cases. Can anyone tell if there is some edge case I've missed?
The original question from CTCI is (Question 8.5 on recursion and dynamic programming):
Write a recursive function to multiply two positive integers without using the * operator. You can use addition, subtraction, and bit shifting, but you should minimize the number of those operations.
My solution is:
def foo(a, b, mult=0):
    if a == 0 or b == 0:
        return 0
    if (a & 1) == 1:
        return (b << mult) + foo(a >> 1, b, mult + 1)
    return foo(a >> 1, b, mult + 1)
Can anyone tell if I've missed something?
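A quick way to look for missed edge cases is to compare foo with the built-in * over a grid of small inputs. This is a minimal sketch: check and its limit parameter are hypothetical names, and the tested range is an arbitrary choice that only covers non-negative values (the problem asks for positive integers, and with a negative a the shift a >> 1 gets stuck at -1 and never reaches 0).

```python
def foo(a, b, mult=0):
    # The question's function, unchanged apart from indentation.
    if a == 0 or b == 0:
        return 0
    if (a & 1) == 1:
        return (b << mult) + foo(a >> 1, b, mult + 1)
    return foo(a >> 1, b, mult + 1)


def check(limit=200):
    # Hypothetical helper: collect every pair 0 <= a, b < limit
    # where foo disagrees with the built-in * operator.
    return [(a, b)
            for a in range(limit)
            for b in range(limit)
            if foo(a, b) != a * b]


print(check())  # an empty list means no mismatch in the tested range
```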
Upvotes: 2
Views: 783
Reputation: 15962
This uses the basic definition of multiplication: 5 x 3 is "5 added 3 times" (5 + 5 + 5) or "3 added 5 times" (3 + 3 + 3 + 3 + 3).
This solution performs 1 addition and 1 subtraction per recursive call:
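def multiply(a, b):
    if b == 1:
        return a
    return a + multiply(a, b-1)  # addition and subtraction
multiply(5, 3)
# 15
multiply(7, 4)
# 28
multiply(2000, 1000)
# 2000000 in principle, but this needs about 1000 nested calls, which is at
# CPython's default recursion limit of 1000, so it may raise RecursionError
# unless the limit is raised with sys.setrecursionlimit()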
You could reduce the number of recursive calls by making b the smaller of the two numbers; put this at the start of the function:
a, b = max(a, b), min(a, b)
# or
a, b = (a, b) if a >= b else (b, a)
Combining those and making an almost-one-liner:
def multiply(a, b):
    a, b = (a, b) if a >= b else (b, a)
    return a if b == 1 else a + multiply(a, b-1)
multiply(2, 1_000_000)
# 2000000
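To make the cost concrete, here is a sketch that instruments the swapped version to count the additions and subtractions it performs (the counter's own bookkeeping is not counted). multiply_counted is a hypothetical name. The count works out to 2 * (min(a, b) - 1), so it still grows linearly with the smaller operand.

```python
def multiply_counted(a, b):
    """Repeated-addition multiply for positive integers that also returns
    how many + and - operations it performed. Hypothetical helper."""
    a, b = (a, b) if a >= b else (b, a)  # make b the smaller operand
    if b == 1:
        return a, 0
    product, ops = multiply_counted(a, b - 1)
    # this level does one subtraction (b - 1) and one addition (a + product)
    return a + product, ops + 2


print(multiply_counted(5, 3))          # (15, 4)
print(multiply_counted(2, 1_000_000))  # (2000000, 2)
```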
Upvotes: 0
Reputation: 22314
Bit-shifting can be used to multiply by a power of 2. It helps to write your multiplications down in binary to see how that helps.
Consider 5 * 7, which in binary is 101 * 111. Applying distributivity, you get:
101 * 111 = 111 * (100 + 001)
          = 111 * 100 + 111 * 001
          = 7 * 4 + 7 * 1
          = 7 * 2² + 7 * 2⁰
          = 35
In general, given a multiplication X * Y, you can partition X into a sum of n powers of 2 and write the multiplication as a sum of n terms of the form 2^k * Y. The terms that appear in your series correspond to the indices k where there is a 1 in the binary representation of X.
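The general rule can be checked with a small sketch that lists the terms 2^k * Y for each set bit k of X, writing 2^k * Y as the shift Y << k. binary_terms is a hypothetical helper name, and it is iterative rather than recursive, purely to make the decomposition visible.

```python
def binary_terms(x, y):
    """Return (k, 2**k * y) for every bit k that is set in x (x >= 0).
    2**k * y is computed as the shift y << k. Hypothetical helper."""
    terms = []
    k = 0
    while x:
        if x & 1:
            terms.append((k, y << k))
        x >>= 1  # move to the next binary digit of x
        k += 1
    return terms


print(binary_terms(5, 7))                      # [(0, 7), (2, 28)]
print(sum(t for _, t in binary_terms(5, 7)))   # 35
```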
The recursive algorithm is thus to look at the binary digits of one of your operands from right to left and sum the terms where you see a 1.
def mult(a, b):
    if a & 1:
        return b + mult(a >> 1, b << 1)
    if a:
        return mult(a >> 1, b << 1)
    else:
        return 0
print(mult(9, 7)) # 63
Your algorithm is very close, but notice that you do not need to keep your mult argument. Instead, you can simply shift your right operand (b) left by 1 every time you look at the next digit.
Also note that this does not terminate if a < 0 (in Python, -1 >> 1 is still -1, so a never reaches 0), but this is easy to fix by doing the positive multiplication and adding back the sign when a is negative.
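A minimal sketch of that sign fix, assuming a wrapper around the mult function above; signed_mult is a hypothetical name. Only a negative a needs handling: b is only shifted left and added, which works for negative Python ints.

```python
def mult(a, b):
    # The positive-operand recursion from this answer.
    if a & 1:
        return b + mult(a >> 1, b << 1)
    if a:
        return mult(a >> 1, b << 1)
    return 0


def signed_mult(a, b):
    # Hypothetical wrapper: a negative a never terminates in mult, because
    # a >> 1 eventually gets stuck at -1, so multiply |a| and restore the sign.
    # Each unary minus here amounts to one subtraction from 0.
    if a < 0:
        return -mult(-a, b)
    return mult(a, b)


print(signed_mult(-3, 7))   # -21
print(signed_mult(-3, -7))  # 21
```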
Upvotes: 1
Reputation: 42143
You can use bit shifting to perform the multiplication in base 2:
def multiply(A,B):
    result = 0
    while A:
        if A&1: result = result + B  # Add shifted B if last bit of A is 1
        A >>= 1                      # next bit of A
        B <<= 1                      # shift B to next bit position
    return result
print(multiply(7,13)) # 91
If you are not allowed to use the & operator, the test (A>>1)<<1 != A does the same thing as A&1 (i.e. it tests the last bit).
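Here is the loop above rewritten with that shift-only bit test, as a sketch that assumes A >= 0; multiply_no_and is a hypothetical name. Clearing the lowest bit with (A >> 1) << 1 changes A exactly when that bit is 1.

```python
def multiply_no_and(A, B):
    # Same loop as above, but the lowest bit of A is tested with shifts:
    # (A >> 1) << 1 is A with its lowest bit cleared, so it differs from A
    # exactly when that bit is 1. Assumes A >= 0.
    result = 0
    while A:
        if (A >> 1) << 1 != A:
            result = result + B  # add shifted B for a 1 bit
        A >>= 1                  # next bit of A
        B <<= 1                  # shift B to the next bit position
    return result


print(multiply_no_and(7, 13))  # 91
```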
If you look at the bit composition of a multiplication you can see what the bit shifting is doing:
# bitPosition  A=13 (1101)  B=7 (111)  A*B=91 (1011011)
#
#      0            1             111       111    7
#      1            0            1110
#      2            1           11100     11100   28
#      3            1          111000    111000   56
#                                       -------   --
#                                       1011011   91
From that, it is relatively easy to write a recursive version:
def multiply(A,B):
    if A == 0: return 0
    result = multiply(A>>1,B)<<1  # get the higher bit product
    if A&1: result = result + B   # add the last bit multiplier (B)
    return result
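The operand swap from the repeated-addition answer also helps here: the recursion makes one call per bit of A plus one for A == 0, so passing the smaller operand as A keeps the call depth (and the number of shifts) low. A sketch, where multiply_small_first is a hypothetical name: for 1_000_000 * 2 it makes 3 calls to multiply instead of 21.

```python
def multiply(A, B):
    # The recursive version from this answer (A must be non-negative).
    if A == 0:
        return 0
    result = multiply(A >> 1, B) << 1  # product of the higher bits, shifted back
    if A & 1:
        result = result + B             # add B for the lowest bit
    return result


def multiply_small_first(A, B):
    # Hypothetical wrapper: recurse on the smaller operand, so the number of
    # calls is its bit length + 1 instead of that of the larger operand.
    if A > B:
        A, B = B, A
    return multiply(A, B)


print(multiply_small_first(1_000_000, 2))  # 2000000
```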
Upvotes: 1