Reputation: 1202
So I've been working recently on an implementation of the Miller-Rabin primality test. I am limiting it to the scope of all 32-bit numbers, because this is a just-for-fun project that I am doing to familiarize myself with C++, and I don't want to have to work with anything 64-bit for a while. An added bonus is that the algorithm is deterministic for all 32-bit numbers, so I can significantly increase efficiency because I know exactly which witnesses to test for.
So for low numbers, the algorithm works exceptionally well. However, part of the process relies upon modular exponentiation, that is (num ^ pow) % mod. so, for example,
3 ^ 2 % 5 =
9 % 5 = 4
here is the code I have been using for this modular exponentiation:
// Computes (num ^ pow) % mod by binary exponentiation (square-and-multiply).
//
// num: base; pow: exponent; mod: modulus (must be non-zero).
// Returns the residue for small operands.
//
// NOTE: this is the version the question is about. Both products below are
// formed in `unsigned` arithmetic, so the result is only correct while
// test * num and num * num do not wrap around the width of `unsigned`.
unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)
{
    unsigned test;
    // Walk the exponent from its least significant bit upwards.
    for (test = 1; pow; pow >>= 1)
    {
        if (pow & 1)
            test = (test * num) % mod;  // may wrap before the reduction
        num = (num * num) % mod;        // may wrap before the reduction
    }
    return test;
}
As you might have already guessed, problems arise when the arguments are all exceptionally large numbers. For example, if I want to test the number 673109 for primality, I will at one point have to find:
(2 ^ 168277) % 673109
now 2 ^ 168277 is an exceptionally large number, and somewhere in the process it overflows test, which results in an incorrect evaluation.
on the reverse side, arguments such as
4000111222 ^ 3 % 1608
also evaluate incorrectly, for much the same reason.
Does anyone have suggestions for modular exponentiation in a way that can prevent this overflow and/or manipulate it to produce the correct result? (the way I see it, overflow is just another form of modulo, that is num % (UINT_MAX+1))
Upvotes: 18
Views: 32327
Reputation: 1794
LL is for long long int
// Returns (a ^ k) mod P using recursive square-and-multiply.
//
// a: base; k: exponent (expected to be >= 0).
// P is the modulus and LL the integer type; both are defined outside this
// snippet (the accompanying text says LL stands for long long int).
// Every operand is reduced mod P before it is multiplied.
LL power_mod(LL a, LL k) {
    if (k == 0)
        return 1 % P;  // yields 0 rather than 1 when P == 1

    // Recurse on half the exponent; the function must call itself here.
    LL temp = power_mod(a, k / 2);
    LL res = ((temp % P) * (temp % P)) % P;
    if (k % 2 == 1)
        res = ((a % P) * (res % P)) % P;
    return res;
}
Use the above recursive function for finding the mod exp of the number. This will not result in overflow because it calculates in a bottom up manner.
Sample test run for :
a = 2
and k = 168277
shows output to be 518358 which is correct and the function runs in O(log(k))
Upvotes: 0
Reputation: 909
package playTime;
/**
 * Demonstrates binary modular exponentiation: computes x^e mod y, where the
 * exponent e is supplied as an array of bits, least significant bit first.
 */
public class play {
    public static long count = 0;
    /** Number of exponent bits consumed by {@link #BME}. */
    public static long binSlots = 10;
    /** Modulus. */
    public static long y = 645;
    /** Running result; must be 1 before a computation starts. */
    public static long finalValue = 1;
    /** Base. */
    public static long x = 11;

    public static void main(String[] args){
        int[] binArray = new int[]{0,0,1,0,0,0,0,1,0,1};
        x = BME(x, count, binArray);
        System.out.print("\nfinal value:"+finalValue);
    }

    /**
     * Processes exponent bit {@code count} and recurses on the next one.
     *
     * @param x        current power of the base, reduced mod y
     * @param count    index of the bit to process
     * @param binArray exponent bits, least significant first
     * @return the accumulated result once binSlots bits have been processed
     */
    public static long BME(long x, long count, int[] binArray){
        if(count == binSlots){
            return finalValue;
        }
        if(binArray[(int) count] == 1){
            finalValue = finalValue*x%y;
        }
        x = (x*x)%y;
        System.out.print("Array("+binArray[(int) count]+") "
                        +"x("+x+")" +" finalVal("+ finalValue + ")\n");
        // Advance to the next bit; without this the recursion never ends.
        return BME(x, count + 1, binArray);
    }
}
Upvotes: 1
Reputation: 3074
I wrote something for this recently for RSA in C++, bit messy though.
#include "BigInteger.h"
#include <iostream>
#include <sstream>
#include <stack>
// Constructs the value zero: one 0 digit and a non-negative sign.
BigInteger::BigInteger() {
    // Other code in this file indexes digits[0] on freshly constructed
    // objects, so a default-constructed value must own exactly one digit.
    digits.assign(1, 0);
    negative = false;
}
// Nothing to release explicitly; the members clean up after themselves.
BigInteger::~BigInteger() {
}
// Stores |a| + |b| in c.digits; the sign flags are neither read nor written.
// Digits are 16-bit values stored least-significant first.
void BigInteger::addWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
    int sum_n_carry = 0;
    int n = (int)a.digits.size();
    if (n < (int)b.digits.size()) {
        n = (int)b.digits.size();
    }
    // c is written by index below, so it must already hold n digits.
    c.digits.resize(n);
    for (int i = 0; i < n; ++i) {
        // Treat the shorter operand as zero-extended.
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size()) {
            a_digit = a.digits[i];
        }
        if (i < (int)b.digits.size()) {
            b_digit = b.digits[i];
        }
        sum_n_carry += a_digit + b_digit;
        c.digits[i] = (sum_n_carry & 0xFFFF);
        sum_n_carry >>= 16;  // what is left is the carry into the next digit
    }
    if (sum_n_carry != 0) {
        putCarryInfront(c, sum_n_carry);
    }
    // Drop leading zero digits, but always keep at least one digit.
    while (c.digits.size() > 1 && c.digits.back() == 0) {
        c.digits.pop_back();
    }
    //std::cout << a.toString() << " + " << b.toString() << " == " << c.toString() << std::endl;
}
// Stores |a| - |b| in c.digits; the sign flags are neither read nor written.
// A borrow left over after the top digit is discarded, so the result is only
// meaningful when |a| >= |b| (the callers in this file compare first).
void BigInteger::subWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
    int sub_n_borrow = 0;
    int n = (int)a.digits.size();
    if (n < (int)b.digits.size())
        n = (int)b.digits.size();
    // c is written by index below, so it must already hold n digits.
    c.digits.resize(n);
    for (int i = 0; i < n; ++i) {
        // Treat the shorter operand as zero-extended.
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        if (i < (int)b.digits.size())
            b_digit = b.digits[i];
        sub_n_borrow += a_digit - b_digit;
        if (sub_n_borrow >= 0) {
            c.digits[i] = sub_n_borrow;
            sub_n_borrow = 0;
        } else {
            // Borrow one unit of the base from the next digit.
            c.digits[i] = 0x10000 + sub_n_borrow;
            sub_n_borrow = -1;
        }
    }
    // Drop leading zero digits, but always keep at least one digit.
    while (c.digits.size() > 1 && c.digits.back() == 0) {
        c.digits.pop_back();
    }
    //std::cout << a.toString() << " - " << b.toString() << " == " << c.toString() << std::endl;
}
// Three-way comparison of magnitudes, ignoring the sign flags.
// Returns -1 if |a| < |b|, +1 if |a| > |b| and 0 if they are equal.
// Operands of different lengths are handled by zero-extending the shorter.
int BigInteger::cmpWithoutSign(const BigInteger& a, const BigInteger& b) {
    int n = (int)a.digits.size();
    if (n < (int)b.digits.size())
        n = (int)b.digits.size();
    //std::cout << "cmp(" << a.toString() << ", " << b.toString() << ") == ";
    // Scan from the most significant digit; the first difference decides.
    for (int i = n-1; i >= 0; --i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        if (i < (int)b.digits.size())
            b_digit = b.digits[i];
        if (a_digit < b_digit) {
            //std::cout << "-1" << std::endl;
            return -1;
        } else if (a_digit > b_digit) {
            //std::cout << "+1" << std::endl;
            return +1;
        }
    }
    //std::cout << "0" << std::endl;
    return 0;
}
// Stores |a| * b in c.digits, where b is a single 16-bit digit.
// The sign flags are neither read nor written.
void BigInteger::multByDigitWithoutSign(BigInteger& c, const BigInteger& a, unsigned short b) {
    unsigned int mult_n_carry = 0;
    // c is written by index below, so it must already hold a's digit count.
    c.digits.resize(a.digits.size());
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        // Widen before multiplying: two unsigned shorts would otherwise be
        // promoted to int, and 0xFFFF * 0xFFFF does not fit in a 32-bit int.
        const unsigned int a_digit = a.digits[i];
        const unsigned int b_digit = b;
        mult_n_carry += a_digit * b_digit;
        c.digits[i] = (mult_n_carry & 0xFFFF);
        mult_n_carry >>= 16;  // what is left is the carry into the next digit
    }
    if (mult_n_carry != 0) {
        putCarryInfront(c, mult_n_carry);
    }
    //std::cout << a.toString() << " x " << b << " == " << c.toString() << std::endl;
}
// Writes into b the digits of a moved up by `times` whole digit positions,
// i.e. a multiplied by 65536^times. Only b.digits is modified.
void BigInteger::shiftLeftByBase(BigInteger& b, const BigInteger& a, int times) {
    b.digits.resize(a.digits.size() + times);
    // The vacated low positions become zero.
    for (int i = 0; i < times; ++i) {
        b.digits[i] = 0;
    }
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        b.digits[i + times] = a.digits[i];
    }
}
// Shifts a right by one bit in place (halves the magnitude, rounding down).
// The digit count is left unchanged.
void BigInteger::shiftRight(BigInteger& a) {
    //std::cout << "shr " << a.toString() << " == ";
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        a.digits[i] >>= 1;
        // The low bit of the next digit up becomes this digit's top bit.
        // That digit has not been shifted yet, so its low bit is intact.
        if (i+1 < (int)a.digits.size()) {
            if ((a.digits[i+1] & 0x1) != 0) {
                a.digits[i] |= 0x8000;
            }
        }
    }
    //std::cout << a.toString() << std::endl;
}
// Shifts a left by one bit in place (doubles the magnitude), growing the
// digit vector when a bit is carried out of the top digit.
void BigInteger::shiftLeft(BigInteger& a) {
    bool lastBit = false;  // bit carried out of the previous (lower) digit
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        bool bit = (a.digits[i] & 0x8000) != 0;
        a.digits[i] <<= 1;
        if (lastBit)
            a.digits[i] |= 1;
        lastBit = bit;
    }
    if (lastBit) {
        // The top digit overflowed: the carried bit becomes a new digit.
        a.digits.push_back(1);
    }
}
// Appends `carry` as a new most-significant digit of a.
// Digits are stored least-significant first, so the new digit goes at the
// end of the vector; the sign of a is left as it is.
void BigInteger::putCarryInfront(BigInteger& a, unsigned short carry) {
    a.digits.push_back(carry);
}
// Binary long division on magnitudes: adds floor(|a| / |b|) to c and stores
// |a| mod |b| in d.
//
// c is only ever accumulated into, so it must be zero on entry (the callers
// in this file pass a freshly constructed object).
// NOTE(review): a zero divisor is not handled; the scaling loop below would
// never terminate. Decide on an error policy before relying on this.
void BigInteger::divideWithoutSign(BigInteger& c, BigInteger& d, const BigInteger& a, const BigInteger& b) {
    BigInteger e = b;     // divisor scaled by a power of two
    BigInteger f("1");    // that power of two
    BigInteger g = a;     // running remainder
    // Work on magnitudes only; `-=` below is sign-aware.
    e.negative = false;
    g.negative = false;

    // Scale the divisor up until it exceeds the dividend...
    while (cmpWithoutSign(g, e) >= 0) {
        shiftLeft(e);
        shiftLeft(f);
    }
    // ...then step back once so that e <= g whenever |a| >= |b|.
    shiftRight(e);
    shiftRight(f);

    while (cmpWithoutSign(g, b) >= 0) {
        g -= e;
        c += f;
        // Scale the divisor back down until it fits into the remainder.
        while (cmpWithoutSign(g, e) < 0) {
            shiftRight(e);
            shiftRight(f);
        }
    }
    d = g;
}
// Copy constructor: duplicates the digits and the sign of `other`.
BigInteger::BigInteger(const BigInteger& other) {
    digits = other.digits;
    negative = other.negative;
}
// Parses a decimal string with an optional leading '-'.
// Characters are not validated: each one is taken as (ch - '0').
BigInteger::BigInteger(const char* other) {
    // Start from zero; the loop below multiplies and adds in place.
    digits.assign(1, 0);
    negative = false;
    BigInteger ten;
    ten.digits[0] = 10;
    const char* c = other;
    bool make_negative = false;
    if (*c == '-') {
        make_negative = true;
        ++c;  // skip the sign character
    }
    while (*c != 0) {
        BigInteger digit;
        digit.digits[0] = *c - '0';
        // value = value * 10 + digit
        *this *= ten;
        *this += digit;
        ++c;
    }
    // Apply the sign last so the accumulation above works on a non-negative value.
    negative = make_negative;
}
// Returns true when the lowest bit of the least significant digit is set.
bool BigInteger::isOdd() const {
    return (digits[0] & 0x1) != 0;
}
// Copy assignment: takes over the digits and the sign of `other`.
BigInteger& BigInteger::operator=(const BigInteger& other) {
    if (this == &other) // handle self assignment
        return *this;
    digits = other.digits;
    negative = other.negative;
    return *this;
}
// Signed addition in place. Returns *this.
// Equal signs: add the magnitudes and keep the sign. Different signs:
// subtract the smaller magnitude from the larger and take the sign of the
// larger operand; equal magnitudes give a non-negative zero.
BigInteger& BigInteger::operator+=(const BigInteger& other) {
    BigInteger result;
    if (negative == other.negative) {
        result.negative = negative;
        addWithoutSign(result, *this, other);
    } else {
        const int order = cmpWithoutSign(*this, other);
        if (order < 0) {
            result.negative = other.negative;
            subWithoutSign(result, other, *this);
        } else if (order > 0) {
            result.negative = negative;
            subWithoutSign(result, *this, other);
        } else {
            // Magnitudes cancel; result keeps its default digits.
            result.negative = false;
        }
    }
    // Commit both the digits and the sign of the computed value.
    digits = result.digits;
    negative = result.negative;
    return *this;
}
// Signed subtraction in place, implemented as adding the negated operand.
// Returns *this.
BigInteger& BigInteger::operator-=(const BigInteger& other) {
    BigInteger neg_other = other;
    neg_other.negative = !neg_other.negative;
    return *this += neg_other;
}
// Signed multiplication in place (schoolbook method). Returns *this.
// For each digit of *this, `other` is multiplied by that digit, moved up by
// the digit's position and added to the running total.
BigInteger& BigInteger::operator*=(const BigInteger& other) {
    BigInteger result;
    for (int i = 0; i < (int)digits.size(); ++i) {
        BigInteger mult;
        multByDigitWithoutSign(mult, other, digits[i]);
        BigInteger shift;
        shiftLeftByBase(shift, mult, i);
        BigInteger add;
        addWithoutSign(add, result, shift);
        result = add;
    }
    // The product is negative exactly when the operand signs differ.
    result.negative = (negative != other.negative);
    //std::cout << toString() << " x " << other.toString() << " == " << result.toString() << std::endl;
    // Commit both the digits and the sign of the computed value.
    digits = result.digits;
    negative = result.negative;
    return *this;
}
// Signed division in place; the remainder is discarded. Returns *this.
// The quotient is negative exactly when the operand signs differ.
BigInteger& BigInteger::operator/=(const BigInteger& other) {
    BigInteger result, tmp;
    divideWithoutSign(result, tmp, *this, other);
    result.negative = (negative != other.negative);
    // Commit both the digits and the sign of the computed value.
    digits = result.digits;
    negative = result.negative;
    return *this;
}
// Replaces *this with the remainder produced by divideWithoutSign; the
// quotient is discarded. Returns *this.
BigInteger& BigInteger::operator%=(const BigInteger& other) {
    BigInteger c, d;
    divideWithoutSign(c, d, *this, other);
    *this = d;
    return *this;
}
// Signed "greater than".
// A non-negative value is greater than any negative one. Two negative values
// compare by magnitude in reverse; two non-negative values compare by magnitude.
bool BigInteger::operator>(const BigInteger& other) const {
    if (negative) {
        if (other.negative) {
            return cmpWithoutSign(*this, other) < 0;
        } else {
            return false;
        }
    } else {
        if (other.negative) {
            return true;
        } else {
            return cmpWithoutSign(*this, other) > 0;
        }
    }
}
// Replaces *this with (*this ^ exponent) % modulus by square-and-multiply.
// Returns *this. Every product is reduced immediately, so intermediate values
// stay below modulus^2.
BigInteger& BigInteger::powAssignUnderMod(const BigInteger& exponent, const BigInteger& modulus) {
    BigInteger zero("0");
    BigInteger one("1");
    BigInteger e = exponent;   // remaining exponent bits
    BigInteger base = *this;   // base^(2^k) mod modulus for the current bit k
    *this = one;
    while (cmpWithoutSign(e, zero) != 0) {
        //std::cout << e.toString() << " : " << toString() << " : " << base.toString() << std::endl;
        if (e.isOdd()) {
            *this *= base;
            *this %= modulus;
        }
        // Consume the bit just handled; without this the loop never ends.
        shiftRight(e);
        base *= BigInteger(base);
        base %= modulus;
    }
    return *this;
}
std::string BigInteger::toString() const {
std::ostringstream os;
if (negative)
os << "-";
BigInteger tmp = *this;
BigInteger zero("0");
BigInteger ten("10");
tmp.negative = false;
std::stack<char> s;
while (cmpWithoutSign(tmp, zero) != 0) {
BigInteger tmp2, tmp3;
divideWithoutSign(tmp2, tmp3, tmp, ten);
s.push((char)(tmp3.digits[0] + '0'));
tmp = tmp2;
while (!s.empty()) {
os <<;
for (int i = digits.size()-1; i >= 0; --i) {
os << digits[i];
if (i != 0) {
os << ",";
return os.str();
And an example usage.
BigInteger a("87682374682734687"), b("435983748957348957349857345"), c("2348927349872344")
// Will Calculate pow(87682374682734687, 435983748957348957349857345) % 2348927349872344
a.powAssignUnderMod(b, c);
It's fast too, and supports an unlimited number of digits.
Upvotes: 7
Reputation: 111280
Two things:
Your code does not work because at one point you have num = 2^16
and the num = (num * num) % mod; step
causes overflow. Use a bigger data type to hold this intermediate value.
How about taking the modulo at every possible overflow opportunity, such as:
test = ((test % mod) * (num % mod)) % mod;
// Computes (num ^ pow) % mod by binary exponentiation (square-and-multiply).
//
// num: base; pow: exponent; mod: modulus (must be non-zero).
// Returns the residue in the range [0, mod).
//
// The running values live in unsigned long long and are kept below mod, so
// each product is below mod^2. Assuming `unsigned` is 32 bits wide, that fits
// in the (at least) 64 bits of unsigned long long and nothing wraps.
unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)
{
    unsigned long long test = 1 % mod;   // 0 when mod == 1
    unsigned long long n = num % mod;    // reduce the base once, up front
    for (; pow; pow >>= 1)
    {
        if (pow & 1)
            test = (test * n) % mod;
        n = (n * n) % mod;
    }
    // test < mod, and mod is an unsigned, so this conversion loses nothing.
    return static_cast<unsigned>(test);
}
int main(int argc, char* argv[])
/* (2 ^ 168277) % 673109 */
printf("%u\n", mod_pow(2, 168277, 673109));
return 0;
Upvotes: 4
Reputation: 279385
Exponentiation by squaring still "works" for modulo exponentiation. Your problem isn't that 2 ^ 168277
is an exceptionally large number, it's that one of your intermediate results is a fairly large number (bigger than 2^32), because 673109 is bigger than 2^16.
So I think the following will do. It's possible I've missed a detail, but the basic idea works, and this is how "real" crypto code might do large mod-exponentiation (although not with 32 and 64 bit numbers, rather with bignums that never have to get bigger than 2 * log (modulus)):
Obviously that's a bit awkward if your C++ implementation doesn't have a 64 bit integer, although you can always fake one.
There's an example on slide 22 here:, although it uses very small numbers (less than 2^16), so it may not illustrate anything you don't already know.
Your other example, 4000111222 ^ 3 % 1608
would work in your current code if you just reduce 4000111222
modulo 1608
before you start. 1608
is small enough that you can safely multiply any two mod-1608 numbers in a 32 bit int.
Upvotes: 8
Reputation: 8063
You could use following identity:
(a * b) (mod m) === (a (mod m)) * (b (mod m)) (mod m)
Try using it straightforward way and incrementally improve.
if (pow & 1)
test = ((test % mod) * (num % mod)) % mod;
num = ((num % mod) * (num % mod)) % mod;
Upvotes: -1