Andrea Goldoni

Reputation: 45

C++: fast modular exponentiation


I'm a computer science student and I have a problem where I must use fast modular exponentiation.
With the code below, the grader reports that the output is incorrect in some cases, although I believe it should be correct.

unsigned long long int pow(int a, int n, int M)
{
    if(n==0)
        return 1;
    if(n==1)
        return a;
    unsigned long long tmp=pow(a, n/2, M)%M;
    if(n%2==0)
        return ((tmp)*(tmp))%M;
    return ((tmp*tmp)*(a%M))%M;
}

With this other code, however, I pass all the test cases.

unsigned long long int pow(int a, int n, int M)
{
    if(n==0)
        return 1;
    if(n==1)
        return a;
    unsigned long long tmp;
    if(n%2==0){
        tmp=pow(a, n/2, M)%M;
        return (tmp*tmp)%M;
    }
    tmp=pow(a, n-1, M)%M;
    return (tmp*(a%M))%M;
}

So my question is: why doesn't the first code pass all the test cases?

Upvotes: 0

Views: 102

Answers (1)

Evg

Reputation: 26272

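First, if n == 1, the return value should be a % M, not a. Second, the product (tmp * tmp) * (a % M) can overflow: with M close to 2^31, tmp * tmp alone can reach about 2^62, so multiplying it by a % M no longer fits in 64 bits. It should be computed as ((tmp * tmp) % M) * (a % M), reducing after the first multiplication. The second version avoids this because it never multiplies more than two reduced values before taking the remainder.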
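To make the overflow concrete, here is a minimal sketch that isolates the odd-exponent step. The helper names `odd_step_unsafe` and `odd_step_safe` are hypothetical and not part of the question; both assume `tmp` is already reduced below `m` and that `m` fits in 32 bits.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helpers for illustration only; the names are not from the question.

// Mirrors the odd branch of the first version: the triple product is formed
// before any reduction, so it can exceed 2^64 and silently wrap around.
std::uint64_t odd_step_unsafe(std::uint64_t tmp, std::uint64_t a, std::uint64_t m) {
    assert(m != 0);  // modulo by zero is undefined
    return ((tmp * tmp) * (a % m)) % m;
}

// Reduces after the first multiplication. With tmp < m < 2^32, every
// intermediate product stays below 2^64.
std::uint64_t odd_step_safe(std::uint64_t tmp, std::uint64_t a, std::uint64_t m) {
    assert(m != 0);
    return (((tmp * tmp) % m) * (a % m)) % m;
}
```

For small inputs the two agree, for example `tmp = 3, a = 2, m = 7`. With `m = 2147483647` and `tmp = a = m - 1`, the true result is `m - 1` (since `(-1)^3 = -1` modulo `m`); only `odd_step_safe` returns it, because the unsafe triple product is roughly 2^93.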

The condition n == 1 doesn't need any special treatment, and the code can be simplified to:

unsigned long long int pow(unsigned int a, unsigned int n, unsigned int m) {
    if (n == 0)
        return 1 % m;  // yields 0 when m == 1

    auto tmp = pow(a, n / 2, m) % m;
    tmp *= tmp;
    if (n % 2)
        tmp = (tmp % m) * (a % m);
    return tmp % m;
}
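The same idea can also be written without recursion. The following is a sketch under stated assumptions, not the answer's own code: the name `mod_pow` is hypothetical, and all three arguments are assumed to fit in 32 bits, so the product of two reduced values always fits in 64 bits.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical iterative variant (name not from the thread): computes (a^n) mod m
// by square-and-multiply, scanning the bits of n from least to most significant.
std::uint64_t mod_pow(std::uint32_t a, std::uint32_t n, std::uint32_t m) {
    assert(m != 0);                 // modulo by zero is undefined
    std::uint64_t result = 1 % m;   // 0 when m == 1
    std::uint64_t base = a % m;     // keep the base reduced below m
    while (n > 0) {
        if (n & 1u)                 // current bit set: fold base into the result
            result = (result * base) % m;
        base = (base * base) % m;   // square for the next bit
        n >>= 1;
    }
    return result;
}
```

For instance, `mod_pow(2, 10, 1000)` evaluates to 24, since 2^10 = 1024. Reducing after every multiplication is what keeps each intermediate value below 2^64.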

Upvotes: 2
