Reputation: 2540
I'm having trouble understanding why this function works. Can someone explain what it's doing step by step? I know the idea is that a^n equals (a^(n/2))^2 if n is even, or a * (a^((n-1)/2))^2 if n is odd, but how does this function do that?
double pow(double a, int n) {
    double ret = 1;
    while(n) {
        if(n%2 == 1) ret *= a;
        a *= a; n /= 2;
    }
    return ret;
}
Upvotes: 1
Views: 212
Reputation: 52592
I'll start with some code that is more obvious:
#include <assert.h>

double pow(double a, int n) {
    int k = 0, m = 1, n2 = n;
    double pow_k = 1.0, pow_m = a;
    assert (n2 * m + k == n);
    while (n2 != 0) {
        if (n2 % 2 != 0) { k += m; pow_k *= pow_m; n2 -= 1; }
        assert (n2 * m + k == n); assert (n2 % 2 == 0);
        m = m * 2; pow_m = pow_m * pow_m; n2 /= 2;
        assert (n2 * m + k == n);
    }
    return pow_k;
}
At every point in the loop, pow_k = a^k and pow_m = a^m. n2 * m + k == n is always true. It is initially true when n2 == n, m == 1, k == 0.
At the if-statement in the loop, n2 is either even or odd. If n2 is even, nothing happens, so the assert remains true and n2 remains even. If n2 is odd, n2 is decreased by 1, which decreases n2 * m by m, and k is increased by m, leaving n2 * m + k unchanged. At the same time pow_k is multiplied by pow_m = a^m, so pow_k = a^k still holds. Either way, n2 is now even.
Then m is doubled and n2 is exactly halved because n2 is even, leaving n2 * m + k again unchanged. Squaring pow_m keeps pow_m = a^m true for the doubled m.
Since n2 is divided by 2 in each iteration, n2 eventually becomes 0, so the loop ends. The assert with n2 == 0 means 0 * m + k == n or k == n, so pow_k = a^k = a^n. The result that is returned is therefore a^n.
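To watch the invariant hold on a concrete input, here is a minimal sketch of the same loop with a print after every iteration. The name pow_traced and the output format are my own additions for illustration, not part of the code above, and it assumes a small non-negative n (m doubles every iteration, so it would overflow an int for very large n).

```c
#include <assert.h>
#include <stdio.h>

/* Hypothetical helper: same loop as the annotated version, but it prints
   the state after every iteration so n2 * m + k == n can be checked by eye.
   Assumes n >= 0 and small enough that m does not overflow. */
double pow_traced(double a, int n) {
    int k = 0, m = 1, n2 = n;
    double pow_k = 1.0, pow_m = a;
    while (n2 != 0) {
        if (n2 % 2 != 0) { k += m; pow_k *= pow_m; n2 -= 1; }
        m *= 2; pow_m *= pow_m; n2 /= 2;
        assert(n2 * m + k == n);   /* the invariant from the explanation */
        printf("n2=%d m=%d k=%d pow_k=%g\n", n2, m, k, pow_k);
    }
    return pow_k;
}
```

Calling pow_traced(2.0, 13) goes through (n2, m, k) = (6, 2, 1), (3, 4, 1), (1, 8, 5), (0, 16, 13) and returns 8192. In every step n2 * m + k is 13, and at the end k itself is 13.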
Now we leave out k, m and the asserts, which doesn't change the calculation:
double pow(double a, int n) {
    int n2 = n;
    double pow_k = 1.0, pow_m = a;
    while (n2 != 0) {
        if (n2 % 2 != 0) { pow_k *= pow_m; n2 -= 1; }
        pow_m = pow_m * pow_m; n2 /= 2;
    }
    return pow_k;
}
We can remove n2 -= 1 when n2 is odd because integer division by 2 discards the remainder anyway, so it doesn't make a difference. And since n isn't used anywhere else, we can just use n instead of n2:
double pow(double a, int n) {
    double pow_k = 1.0, pow_m = a;
    while (n != 0) {
        if (n % 2 != 0) pow_k *= pow_m;
        pow_m = pow_m * pow_m; n /= 2;
    }
    return pow_k;
}
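Dropping n2 -= 1 relies on C integer division truncating toward zero, so for a positive odd n2 both (n2 - 1) / 2 and n2 / 2 give the same quotient. A small sketch to check that claim; the helper names same_half and shortcut_is_safe are hypothetical and only serve this illustration:

```c
/* Hypothetical helper: does subtracting 1 before halving leave the
   quotient unchanged for this n2? */
int same_half(int n2) {
    return (n2 - 1) / 2 == n2 / 2;
}

/* Checks every odd n2 in 1..limit; returns 1 if the shortcut is safe
   for all of them, 0 otherwise. */
int shortcut_is_safe(int limit) {
    for (int n2 = 1; n2 <= limit; n2 += 2) {
        if (!same_half(n2)) return 0;
    }
    return 1;
}
```

The equivalence holds only for non-negative values: for n2 = -13, (n2 - 1) / 2 is -7 while n2 / 2 is -6. That is fine here, because the function is only correct for n >= 0 in the first place.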
Now we change pow_k to ret, and pow_m to a, and change n % 2 != 0 to n % 2 == 1, and we get the original code:
double pow(double a, int n) {
    double ret = 1.0;
    while (n != 0) {
        if (n % 2 == 1) ret *= a;
        a *= a; n /= 2;
    }
    return ret;
}
Upvotes: 1
Reputation: 27424
The program relies on the following equalities.
The loop invariant is that a^n * ret is the result. In fact, at the beginning ret is 1, while at the end of the loop n == 0, so that a^0 * ret is the result, and since a^0 == 1, ret is the expected result.
If n is odd (i.e. n%2 == 1), then there exists a b≥0 such that n=b*2+1. In this case, we use the following equality: a^(b*2+1)=(a^(b*2))*a. So ret is multiplied by a.
In both cases, a^(b*2) = (a^2)^b, so that a is multiplied by itself and n is divided by 2, and the invariant is maintained.
Note that inside the loop, integer division is used in n /= 2, so that the result is always b in both cases (n odd, that is n=b*2+1, or n even, that is n=b*2).
Finally, note that, as pointed out by @chux in a comment, the function does not correctly handle negative values of n.
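To make the negative case concrete: in C99 and later, n % 2 is -1 for a negative odd n, so the test n%2 == 1 never succeeds and the original function returns 1 for every negative n. One possible way to handle it, as a sketch only (the name pow_signed is my own, and the approach of computing with |n| and taking the reciprocal is an assumption, not something from the original code):

```c
/* Hypothetical variant that also accepts negative exponents.
   It runs the same loop on the magnitude of n and inverts the
   result at the end when n was negative. */
double pow_signed(double a, int n) {
    /* Take the magnitude as unsigned so that negating INT_MIN
       does not overflow. */
    unsigned int e = (n < 0) ? 0u - (unsigned int)n : (unsigned int)n;
    double ret = 1.0;
    while (e != 0) {
        if (e % 2 == 1) ret *= a;   /* odd: peel off one factor of a */
        a *= a;                     /* a^(2b) == (a^2)^b */
        e /= 2;
    }
    return (n < 0) ? 1.0 / ret : ret;
}
```

With this, pow_signed(2.0, -3) gives 0.125, while non-negative exponents behave as in the original function.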
Upvotes: 3
Reputation: 3058
Here is my recursive Python code, which is IMO more readable and understandable (I know it is not a good idea to write recursive functions in Python, but I chose Python because its simple syntax makes the idea easy to demonstrate).
def pow(n, e):
    if e == 0:
        return 1
    if e % 2 == 1:
        return n * pow(n, e - 1)
    # this step makes the algorithm run in O(log e) time
    tmp = pow(n, e // 2)
    return tmp * tmp
I will stress once again: tmp = pow(n, e // 2) is the line that reduces the time complexity. (Note the integer division //, since e / 2 would produce a float in Python 3.)
Instead of multiplying n by itself e times, the algorithm reuses previously calculated results. For example, 2^8 is calculated as 2^4 * 2^4. Here 2^4 is calculated only once, so half of the multiplications are skipped. The same applies to 2^4, and so on.
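To see the saving in numbers, here is a rough C transliteration of the same recursion with a multiplication counter added. The counter and the names pow_counted and count_multiplications are hypothetical additions for illustration, and the sketch assumes e >= 0.

```c
/* Hypothetical global counter, reset by count_multiplications(). */
static int mul_count;

/* Same recursion as the Python version, counting each multiplication. */
double pow_counted(double n, int e) {
    if (e == 0) return 1.0;
    if (e % 2 == 1) {
        mul_count++;
        return n * pow_counted(n, e - 1);
    }
    double tmp = pow_counted(n, e / 2);   /* computed once, used twice */
    mul_count++;
    return tmp * tmp;
}

/* Returns how many multiplications pow_counted(n, e) performs. */
int count_multiplications(double n, int e) {
    mul_count = 0;
    pow_counted(n, e);
    return mul_count;
}
```

For 2^8 this performs 4 multiplications (one each for the exponents 1, 2, 4 and 8), compared with 8 for a loop that multiplies by n once per unit of the exponent.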
I tried to explain it more intuitively, without going deep into the theory behind this optimization. If you want to understand it more deeply, including how it works at the bit level, here is a good tutorial.
Upvotes: 1