Vincent

Reputation: 60451

(n*2-1)%p: avoiding the use of 64 bits when n and p are 32 bits

Consider the following function:

inline unsigned int f(unsigned int n, unsigned int p) 
{
    return (n*2-1)%p;
}

Now suppose that n (and p) are greater than std::numeric_limits<int>::max().

For example f(4294967295U, 4294967291U).

The mathematical result is 7, but the function returns 2 because n*2 wraps around modulo 2^32.
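Both numbers can be reproduced at compile time. The sketch below assumes a 32-bit `unsigned int` and a 64-bit `unsigned long long`. The names `f32` and `f64` are hypothetical and exist only for this comparison: `f32` is the original expression, `f64` the widened one.

```cpp
// Sketch assuming 32-bit unsigned int and 64-bit unsigned long long.
// f32 and f64 are hypothetical names used only for this comparison.

// Original expression: n*2 is evaluated in 32-bit unsigned arithmetic and wraps.
constexpr unsigned int f32(unsigned int n, unsigned int p)
{
    return (n * 2 - 1) % p;
}

// Widened expression: the product is evaluated in 64 bits, so nothing wraps.
constexpr unsigned int f64(unsigned int n, unsigned int p)
{
    return static_cast<unsigned int>((static_cast<unsigned long long>(n) * 2 - 1) % p);
}

// For n = 4294967295, p = 4294967291:
//   32-bit: n*2 wraps to 4294967294, minus 1 gives 4294967293, and 4294967293 % p == 2
//   64-bit: n*2 - 1 == 8589934589, and 8589934589 - 2*p == 7
static_assert(f32(4294967295U, 4294967291U) == 2U, "wrapped result");
static_assert(f64(4294967295U, 4294967291U) == 7U, "exact result");
```

For small inputs the two agree. For example, both give 5 for n = 10, p = 7.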

The solution is simple: use a 64-bit integer instead. Assuming that the declaration of the function has to stay the same:

inline unsigned int f(unsigned int n, unsigned int p) 
{
    return (static_cast<unsigned long long int>(n)*2-1)%p;
}

Everything is fine, at least in principle. The problem is that this function (I mean the overflow-prone version) is called millions of times in my code, and a 64-bit modulus is much slower than the 32-bit one (see here for example).

The question is the following: is there any trick (mathematical or algorithmic) to avoid executing a 64-bit modulus operation? And what would a new version of f using this trick look like (keeping the same declaration)?

Upvotes: 3

Views: 193

Answers (3)

Columbo

Reputation: 60999

FWIW, this version seems to avoid any overflow:

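#include <cstdint>

std::uint32_t f(std::uint32_t n, std::uint32_t p) 
{
    auto m = n%p;               // 0 <= m < p
    if (m <= p/2) {
        // 2*m <= p, so 2*m-1 cannot overflow; m == 0 maps to p-1
        return (m==0)*p+2*m-1;
    }
    // here 2*m-1 >= p, so the result is 2*m-1-p, computed via p-m
    return p-2*(p-m)-1;
}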

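The idea is that when m > p/2, 2*m-1 may overflow (and in any case it is at least p), so the code computes the reduced value 2*m-1-p as p-2*(p-m)-1. This multiplies 2 by the modular additive inverse p-m instead of by m, and since p-m < p/2 in that branch, nothing overflows.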
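The arithmetic behind the case split can be restated as follows. This is a restatement, not part of the original answer, assuming p >= 1 and writing m = n mod p as in the code:

```latex
\[
(2n - 1) \bmod p \;=\; (2m - 1) \bmod p \;=\;
\begin{cases}
p - 1, & m = 0,\\
2m - 1, & 1 \le m \le \lfloor p/2 \rfloor,\\
2m - 1 - p \;=\; p - 2(p - m) - 1, & \lfloor p/2 \rfloor < m < p.
\end{cases}
\]
```

In the middle case 2m <= p, and in the last case 2(p - m) < p, so every intermediate value fits in 32 bits.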

Upvotes: 1

Ishamael

Reputation: 12795

We know that p does not exceed the maximum unsigned value, so neither does n % p. Both are unsigned, so n % p is non-negative and smaller than p. Unsigned wraparound is well-defined, so if n % p * 2 exceeds p, we can compute the reduced value n % p * 2 - p as n % p - p + n % p. The intermediate n % p - p wraps, but the wraparound cancels once n % p is added back, and the final value lies strictly between 0 and p. Put together, it looks like this:

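inline unsigned int f(unsigned int n, unsigned int p)
{
    unsigned m = n % p;
    unsigned r;
    if (p - m < m) // m * 2 > p
        r = m - p + m;  // 2*m - p; the wraparound in m - p cancels out
    else // m * 2 <= p
        r = m * 2;

    // subtract 1, account for the fact that r can be 0
    if (r == 0) r = p - 1;
    else r = r - 1;
    return r % p;
}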

Note that you can avoid the last modulus, because r never exceeds p * 2. In fact, before the subtraction it never exceeds p, since it is either 2*m - p < p or 2*m <= p. So the last line can be rewritten as

return r >= p ? r - p : r;

This brings the number of modulus operations down to one.
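Putting the pieces together, a complete single-modulus function might look like the sketch below. It mirrors the answer's code, including the final conditional subtraction. The names `f_one_mod` and `f_ref` are hypothetical. `f_ref` is a 64-bit reference included only so the sketch can check itself. The sketch assumes a 32-bit `unsigned int` and C++14 `constexpr`.

```cpp
// Sketch of the single-modulus version (C++14 constexpr, 32-bit unsigned int assumed).
// f_one_mod and f_ref are hypothetical names; f_ref is a 64-bit reference for checking only.
constexpr unsigned int f_one_mod(unsigned int n, unsigned int p)
{
    unsigned int m = n % p;        // the only modulus: 0 <= m < p
    unsigned int r = 0;
    if (p - m < m)                 // 2*m > p: compute 2*m - p; the wraparound in m - p cancels
        r = m - p + m;
    else                           // 2*m <= p: fits in 32 bits
        r = m * 2;

    r = (r == 0) ? p - 1 : r - 1;  // subtract 1 modulo p
    return r >= p ? r - p : r;     // conditional subtraction instead of a second %
}

constexpr unsigned int f_ref(unsigned int n, unsigned int p)
{
    return static_cast<unsigned int>((static_cast<unsigned long long>(n) * 2 - 1) % p);
}

// Compile-time checks near the 32-bit limit (n = 0 is avoided: the 64-bit
// reference itself wraps there).
static_assert(f_one_mod(4294967295U, 4294967291U) == f_ref(4294967295U, 4294967291U), "");
static_assert(f_one_mod(4294967291U, 4294967291U) == f_ref(4294967291U, 4294967291U), "");
static_assert(f_one_mod(2147483648U, 4294967295U) == f_ref(2147483648U, 4294967295U), "");
```

The case n = 2147483648, p = 4294967295 exercises the wrapping branch. There, m - p wraps to 2147483649, and adding m back yields 1.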

Upvotes: 3

user555045

Reputation: 64913

Even though I dislike dealing with AT&T syntax and GCC's "extended asm" constraints, I think this works (it did in my admittedly limited tests):

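#include <stdint.h>

uint32_t f(uint32_t n, uint32_t p)
{
    uint32_t res;
    uint32_t lo = n;  // copy of n: eax is overwritten (divl leaves the quotient there)
    asm (
      "xorl %%edx, %%edx\n\t"      // edx = 0
      "addl %%eax, %%eax\n\t"      // eax = low half of 2*n, CF = carry
      "adcl %%edx, %%edx\n\t"      // edx = carry (high half of 2*n)
      "subl $1, %%eax\n\t"         // subtract 1 from the low half...
      "sbbl $0, %%edx\n\t"         // ...and propagate the borrow into edx
      "divl %[divisor]"            // edx:eax / p: quotient -> eax, remainder -> edx
      : "=d"(res), "+a"(lo)
      : [divisor] "S"(p)
      : "cc"
      );
    return res;
}

The constraints may be unnecessarily strict. Because the asm overwrites eax and the flags, n is passed through the read-write operand "+a"(lo) rather than as a plain "a" input, and "cc" is listed as a clobber. It seemed to work in my tests.

The idea here is to do a regular 32-bit division (divl), which actually takes a 64-bit dividend in edx:eax. It only works if the quotient fits in 32 bits (otherwise a divide error is raised), which is always true under the circumstances (p at least 2, n not zero). The instructions before the division handle the doubling (with the carry going into edx, the "high half") and then the "subtract 1" with a potential borrow. The "=d" output takes the remainder, which divl leaves in edx, as the result. "+a"(lo) puts n in eax; letting the compiler choose another register doesn't help, because the division takes its input in edx:eax anyway. "S"(p) could be replaced by "r"(p) only if the output is also marked early-clobber ("=&d"). Otherwise the compiler is free to place p in edx, which the first instruction zeroes.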
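The instructions before divl can be mirrored in portable C++ to see what they compute. The sketch below assumes C++14 `constexpr`, and the names `Pair32`, `two_n_minus_one` and `divl_quotient_fits` are hypothetical. It builds the 64-bit value 2*n - 1 as a high half (edx) and a low half (eax) using only 32-bit operations with explicit carry and borrow. It also encodes when divl can succeed: the quotient fits in 32 bits exactly when the high half of the dividend is below the divisor.

```cpp
#include <cstdint>

// Hypothetical portable mirror of the xorl/addl/adcl/subl/sbbl sequence.
struct Pair32 { std::uint32_t hi, lo; };  // hi ~ edx, lo ~ eax

constexpr Pair32 two_n_minus_one(std::uint32_t n)
{
    std::uint32_t lo = n + n;                    // addl %eax, %eax (wraps mod 2^32)
    std::uint32_t hi = (lo < n) ? 1u : 0u;       // adcl into zeroed edx: carry out of the add
    std::uint32_t borrow = (lo == 0) ? 1u : 0u;  // subl $1 borrows only when lo is 0
    lo = lo - 1;                                 // subl $1, %eax
    hi = hi - borrow;                            // sbbl $0, %edx
    return Pair32{hi, lo};
}

// divl raises a divide error unless the quotient fits in 32 bits, which holds
// exactly when the high half of the dividend is smaller than the divisor.
constexpr bool divl_quotient_fits(std::uint32_t n, std::uint32_t p)
{
    return two_n_minus_one(n).hi < p;
}
```

For n >= 1 the high half is at most 1, so the division is safe whenever p >= 2. For n = 0 the pair becomes 0xFFFFFFFF:0xFFFFFFFF (that is, -1), and no 32-bit p can satisfy the condition. This matches the "p at least 2, n not zero" requirement above.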

Upvotes: 1
