b_pcakes

Reputation: 2540

Understanding fast exponentiation function

I'm having trouble understanding why this function works. Can someone explain what it's doing step by step? I know the idea is that a^n equals (a^(n/2))^2 if n is even, or a * (a^((n-1)/2))^2 if n is odd, but how does this function do that?

double pow(double a, int n) {
    double ret = 1;
    while(n) {
        if(n%2 == 1) ret *= a;
        a *= a; n /= 2;
    }
    return ret;
}

Upvotes: 1

Views: 212

Answers (3)

gnasher729

Reputation: 52592

I'll start with some code that is more obvious:

double pow(double a, int n) {
    int k = 0, m = 1, n2 = n;
    double pow_k = 1.0, pow_m = a;
    assert (n2 * m + k == n);

    while (n2 != 0) {
        if (n2 % 2 != 0) { k += m; pow_k *= pow_m; n2 -= 1; }
        assert (n2 * m + k == n); assert (n2 % 2 == 0);
        m = m * 2; pow_m = pow_m * pow_m; n2 /= 2;
        assert (n2 * m + k == n);
    }

    return pow_k;
}

At every point in the loop, pow_k = a^k and pow_m = a^m. n2 * m + k == n is always true. It is initially true when n2 == n, m == 1, k == 0.

At the if-statement in the loop, n2 is either even or odd. If it is even, nothing changes, so the assert remains true and n2 remains even. If it is odd, n2 is decreased by 1, which decreases n2 * m by m, and k is increased by m, leaving n2 * m + k unchanged. pow_k is multiplied by pow_m at the same time, so pow_k = a^k still holds. And n2 is now even.

Then m is doubled and n2 is exactly halved because n2 is even, leaving n2 * m + k again unchanged. pow_m is squared, so pow_m = a^m still holds.

Since n2 is divided by 2 in each iteration, n2 eventually becomes 0, so the loop ends. The assert with n2 == 0 means 0 * m + k == n or k == n, so pow_k = a^k = a^n. The result that is returned is therefore a^n.
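
If you want to watch the bookkeeping without the floating-point part, here is a minimal sketch that keeps only k, m and n2 and reports whether n2 * m + k == n held at every check. The name invariant_holds is made up for illustration, and it assumes n >= 0.

```c
#include <assert.h>

/* Hypothetical helper: runs only the bookkeeping part of the loop
   (k, m, n2) and returns 1 if n2 * m + k == n held at every check,
   0 otherwise. Assumes n >= 0. */
int invariant_holds(int n) {
    assert(n >= 0);
    long long k = 0, m = 1, n2 = n;   /* long long so that m cannot overflow */

    if (n2 * m + k != n) return 0;
    while (n2 != 0) {
        if (n2 % 2 != 0) { k += m; n2 -= 1; }   /* move one m from n2 * m into k */
        if (n2 * m + k != n || n2 % 2 != 0) return 0;
        m *= 2; n2 /= 2;                        /* n2 is even, so n2 * m is unchanged */
        if (n2 * m + k != n) return 0;
    }
    return k == n;                              /* n2 == 0 forces k == n */
}
```

For n = 13 the triple (n2, m, k) at the end of each iteration is (6, 2, 1), (3, 4, 1), (1, 8, 5), (0, 16, 13), and n2 * m + k is 13 every time.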

Now we leave out k, m and the asserts, which doesn't change the calculation:

double pow(double a, int n) {
    int n2 = n;
    double pow_k = 1.0, pow_m = a;

    while (n2 != 0) {
        if (n2 % 2 != 0) { pow_k *= pow_m; n2 -= 1; }
        pow_m = pow_m * pow_m; n2 /= 2;
    }

    return pow_k;
}

We can remove n2 -= 1 when n2 is odd, because the integer division by 2 that follows discards the remainder anyway. And since n is no longer used for anything else, we can drop n2 and work on n directly:

double pow(double a, int n) {
    double pow_k = 1.0, pow_m = a;

    while (n != 0) {
        if (n % 2 != 0) pow_k *= pow_m;
        pow_m = pow_m * pow_m; n /= 2;
    }

    return pow_k;
}

Now we change pow_k to ret, and pow_m to a, and change n % 2 != 0 to n % 2 == 1, and we get the original code:

double pow(double a, int n) {
    double ret = 1.0;

    while (n != 0) {
        if (n % 2 == 1) ret *= a;
        a *= a; n /= 2;
    }

    return ret;
}

Upvotes: 1

Renzo

Reputation: 27424

The program relies on the following equalities:

  1. The loop invariant is that, at each step of the loop, a^n * ret (with the current values of a, n and ret) equals the final result. At the beginning ret is 1, so a^n * ret is exactly the value we want. At the end of the loop n == 0, so the result is a^0 * ret, and since a^0 == 1, ret is the expected result.
  2. If n is odd (i.e. n%2 == 1), then there exists a b≥0 such that n=b*2+1. In this case, we use the following equality: a^(b*2+1)=(a^(b*2))*a. So ret is multiplied by a.
  3. In the next statement, the following equality is used: a^(b*2) = (a^2)^b, so a is multiplied by itself and n is divided by 2, which restores the invariant.

Note that n /= 2 inside the loop is an integer division, so the result is b in both cases (n odd, that is n=b*2+1, or n even, that is n=b*2).
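
The invariant can also be checked mechanically. The following is only a sketch: slow_pow and loop_invariant_holds are names I made up, and it uses unsigned integers instead of double so that the comparisons are exact. It assumes a^n fits in an unsigned long long.

```c
#include <assert.h>

/* Naive reference: multiplies by a, n times. */
static unsigned long long slow_pow(unsigned long long a, unsigned int n) {
    unsigned long long r = 1;
    while (n-- > 0) r *= a;
    return r;
}

/* Hypothetical checker: returns 1 if a^n * ret equals the final answer
   at the top of every iteration and ret equals it after the loop.
   Assumes the original a^n fits in unsigned long long. */
int loop_invariant_holds(unsigned long long a, unsigned int n) {
    const unsigned long long expected = slow_pow(a, n);
    unsigned long long ret = 1;

    assert(expected != 0 || a == 0);   /* a nonzero base never yields 0 without overflow */
    while (n) {
        if (slow_pow(a, n) * ret != expected) return 0;   /* the invariant */
        if (n % 2 == 1) ret *= a;    /* equality 2: a^(b*2+1) = a^(b*2) * a */
        a *= a; n /= 2;              /* equality 3: a^(b*2) = (a^2)^b */
    }
    return ret == expected;          /* n == 0, so a^0 * ret == ret */
}
```

Calling loop_invariant_holds(3, 13) walks through n = 13, 6, 3, 1 and returns 1.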

Finally, note that, as pointed out by @chux in a comment, the function does not handle negative values of n correctly.
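
To make the problem concrete: in C99 and later, n % 2 is -1 for a negative odd n, so the test n%2 == 1 never succeeds and the original function returns 1. One possible way around it, shown here only as a sketch, is to run the same loop on |n| and take the reciprocal at the end. The name pow_signed is made up, and rejecting a zero base with a negative exponent is my own assumption about what the caller wants.

```c
#include <assert.h>

/* Sketch: same loop as the original, run on |n|, with the reciprocal
   taken for negative exponents. pow_signed is a hypothetical name. */
double pow_signed(double a, int n) {
    /* Assumption: 0 raised to a negative power is treated as a caller error. */
    assert(a != 0.0 || n >= 0);

    /* Take |n| in unsigned arithmetic so that INT_MIN does not overflow. */
    unsigned int e = (n < 0) ? 0u - (unsigned int)n : (unsigned int)n;
    double ret = 1.0;

    while (e != 0) {
        if (e % 2 == 1) ret *= a;
        a *= a; e /= 2;
    }
    return (n < 0) ? 1.0 / ret : ret;
}
```

With this version pow_signed(2.0, -3) gives 0.125, while non-negative exponents behave as before.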

Upvotes: 3

giliev

Reputation: 3058

Here is my recursive Python code, which is IMO more readable and understandable (I know it is not a good idea to write recursive functions in Python, but I chose Python because its simple syntax makes the idea easy to demonstrate).

def pow(n, e):
    if e == 0:
        return 1

    if e % 2 == 1:
        return n * pow(n, e - 1)

    # this step makes the algorithm run in O(lg e) time;
    # // is integer division, so e stays an int
    tmp = pow(n, e // 2)

    return tmp * tmp

I will stress once again: tmp = pow(n, e // 2) is the line that reduces the time complexity.

Instead of multiplying by n e times, the algorithm reuses previously calculated results. For example, 2^8 is calculated as 2^4 * 2^4. Here 2^4 is calculated only once, so half of the multiplications are skipped. The same applies to 2^4, and so on.
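
Since the question is about C, here is a rough C rendering of the same recursive idea, with a counter added so you can see how few multiplications are performed. The name pow_rec and the mults parameter are made up for illustration, and the sketch assumes e >= 0.

```c
#include <assert.h>

/* Recursive square-and-multiply. *mults is incremented once per
   multiplication; it exists only for illustration. Assumes e >= 0. */
double pow_rec(double n, int e, int *mults) {
    assert(e >= 0);
    if (e == 0) return 1.0;

    if (e % 2 == 1) {
        ++*mults;
        return n * pow_rec(n, e - 1, mults);
    }

    double tmp = pow_rec(n, e / 2, mults);   /* computed once, used twice */
    ++*mults;
    return tmp * tmp;
}
```

For 2^8 the counter ends at 4 (one for 2^1, then one squaring each for 2^2, 2^4 and 2^8), whereas a plain loop starting from 1 would multiply 8 times.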

I tried to explain it more intuitively, without going deep into the theory behind this optimization. If you want to understand it more deeply, including how it works at the bit level, here is a good tutorial
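
As a rough sketch of the bit-level view: the loop from the question looks at the binary digits of the exponent from lowest to highest. For example 13 is 1101 in binary, so a^13 = a^8 * a^4 * a^1. The name pow_bits is made up, and the sketch assumes n >= 0.

```c
#include <assert.h>

/* Same loop as in the question, written with bit operations.
   pow_bits is a hypothetical name. Assumes n >= 0. */
double pow_bits(double a, int n) {
    assert(n >= 0);
    unsigned int bits = (unsigned int)n;
    double ret = 1.0;

    while (bits != 0) {
        if (bits & 1u) ret *= a;   /* lowest bit set: this power of a is a factor of the result */
        a *= a;                    /* a becomes a^2, a^4, a^8, ... */
        bits >>= 1;                /* drop the bit that was just handled */
    }
    return ret;
}
```

Here bits & 1u plays the role of n % 2 == 1 and bits >>= 1 plays the role of n /= 2.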

Upvotes: 1
