Reputation: 69
Consider the code segment below
Explanation - It performs division of two numbers n & d where size(d) <= 128 bits and 128 bits <= size(n) <= 256 bits. The data types are __uint128_t n[2], d[2], temp[8], q[2], temp_1
. The values of n & d are obtained from the previous section of the code. I have printed the values for reference.
printf("YEP 6\n");
printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
q[1] = n[1]/d[0];
n[1] -= d[0]*q[1];
q[0] = 0;
printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
printf("d[1] - %lx ",d[1]>>64); printf("%lx ",d[1]); printf("d[0] - %lx ",d[0]>>64); printf("%lx\n",d[0]);
temp[1] = d[0]>>64; temp[0] = d[0]&0xffffffffffffffff;
while(n[1])
{
printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
q[0] += n[1];
temp[3] = n[1]>>64; temp[2] = n[1]&0xffffffffffffffff;
temp_1 = temp[0]*temp[2]; temp[4] = temp_1&0xffffffffffffffff; temp[5] = temp_1>>64; temp_1 = temp[1]*temp[2]; temp[5] += temp_1&0xffffffffffffffff; temp[6] = temp_1>>64;
temp_1 = temp[0]*temp[3]; temp[5] += temp_1&0xffffffffffffffff; temp[6] += temp_1>>64; temp_1 = temp[1]*temp[3]; temp[6] += temp_1&0xffffffffffffffff; temp[7] = temp_1>>64;
temp[6] += temp[5]>>64; temp[7] = (temp[7]<<64) + temp[6]; temp[6] = temp[5]<<64 + temp[4];
n[1] = n[1] - temp[7] - (temp[6] > n[0]); n[0] -= temp[6];
}
The output of this segment is
YEP 6
n[1] - 0 ffffffffffffffff n[0] - ffffffffffffffff fffffffffffffff
n[1] - 0 ffffffffffffffff n[0] - ffffffffffffffff fffffffffffffff
d[1] - 0 0 d[0] - f000000000000000 fffffffefffffc2f
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
The output goes on to infinity with n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
repeating. Notice that the value of n[1] & n[0] has changed when it enters the loop even though I have not made any changes (I have performed subtraction with n[1] above which is why i printed the values before and after to denote no change). I have printed n[1] & n[0] before the start of the while loop and after the while loop to verify that.
This is causing no changes to n[1] overall because n[1] has been changed to 1
instead of 0xffffffffffffffff
and thus n[1] will not be changed by remaining code which is causing the infinite loop.
Why is this happening? Any memory corruption of sorts? Or anything with while loops i seem to have missed? I even changed to while(n[1] > 0) and no changes.
I am compiling using gcc/g++.
gcc (Ubuntu 11.3.0-1ubuntu1~22.04.1) 11.3.0
g++ (Ubuntu 11.3.0-1ubuntu1~22.04.1) 11.3.0
The Complete program
Explanation - This program is aimed to calculate the multiplicative inverse of a given number in a prime field only in C to be used in CUDA. So, mpz_invert() in GMP library or other libraries cannot be used. The bit shifting technique normally used for division of very large numbers is very slow as i have to perform trillions of such calculations. This is still to be tested whether it gives the correct answer or not. The issue was found when i was verifying. It compiles though..
#include <stdio.h>
#define UINT128(hi, lo) (((__uint128_t) (hi)) << 64 | (lo))
int main()
{
__uint128_t n[2], d[2], q[2], temp[8], temp_1, m[2], m1[2], k1[2], w[2];
__uint64_t value;
double helper;
int sign_1, sign_2, i; // 1 is +ve
n[1] = UINT128(0xffffffffffffffff, 0xffffffffffffffff); n[0] = UINT128(0xffffffffffffffff, 0xfffffffefffffc2f);
k1[1] = UINT128(0x0, 0xffffffffffffffff); k1[0] = UINT128(0xffffffffffffffff, 0x0fffffffffffffff);
d[1] = k1[1]; d[0] = k1[0]; m[1] = 0; m[0] = 0; m1[1] = 0; m1[0] = 1;// If k1 is not used here, you can shift d for k1;
sign_1 = 1; sign_2 = 1;
// printf("d[1] - %lx \n", d[1]>>64); printf("d[1] - %lx \n", d[1]);
while(d[1] != 0)
{
q[1] = n[1]/d[1]; q[0] = 0; printf("q[1] - %lx ",q[1]>>64); printf("%lx\n",q[1]);
temp[1] = d[0]>>64; temp[0] = d[0]&0xffffffffffffffff; // This is done to avoid multiple computations of d[1, 0]
while(q[1])
{
printf("YEP 1\n");
// Now, multiply q[1] with d[0] to get the overflow.. (temp[1]temp[0] X temp[3]temp[2]) = temp[7]temp[6]temp[5]temp[4]
temp[3] = q[1]>>64; temp[2] = q[1]&0xffffffffffffffff;
temp_1 = temp[0]*temp[2]; temp[4] = temp_1&0xffffffffffffffff; temp[5] = temp_1>>64; temp_1 = temp[1]*temp[2]; temp[5] += temp_1&0xffffffffffffffff; temp[6] = temp_1>>64; temp_1 = temp[0]*temp[3]; temp[5] += temp_1&0xffffffffffffffff; temp[6] += temp_1>>64; temp_1 = temp[1]*temp[3]; temp[6] += temp_1&0xffffffffffffffff; temp[7] = temp_1>>64;
temp[6] += temp[5]>>64; temp[7] = (temp[7]<<64) + temp[6];
if(d[1] >= temp[7])
{
printf("1 - %lx ",n[1]>>64); printf("%lx\n",n[1]); printf("temp[7] - %lx %lx\n",temp[7]>>64, temp[7]);
// Multiply with q[1] and check and subtract 1 if needed!!
temp[3] = q[1]*d[1] + temp[7]; temp[2] = (temp[5]<<64) + temp[4];// Wont temp[3] overflow??? It wont!!
q[0] += q[1];
if((temp[3] > n[1]) || ((temp[3] == n[1]) && (temp[2] > n[0])))
{
q[0]--;
// Subtract d[1, 0] from temp[3, 2] and put the value in temp[3, 2]
temp[3] = temp[3] - d[1] - (d[0] > temp[2]); temp[2] -= d[0];
}
// Subtract temp[3, 2] from n[1, 0] and put the remainder in n[1, 0]
n[1] = n[1] - temp[3] - (temp[2] > n[0]); n[0] -= temp[2];
printf("%lx ",n[1]); printf("%lx\n",n[0]);
}
else
{
printf("2 - n[1] %lx ",n[1]>>64); printf("%lx\n",n[1]); printf("temp[7] - %lx %lx\n",temp[7]>>64, temp[7]);
helper = (d[0]>>96);
helper = (0xffffffff/helper)*10;// Also check if 32 bit approx is enough to justify 128 bit behaviour
value = helper; printf("value - %lx\n",value);
temp[3] = d[1]*value + 10; temp[2] = temp[7]/temp[3]; temp[4] = temp[7]%temp[3]; printf("temp[3] - %lx ",(temp[2]*value + ((temp[4]*value)/temp[3]) + 1)>>64); printf("%lx\n",(temp[2]*value + ((temp[4]*value)/temp[3]) + 1));
q[1] -= (temp[2]*value + ((temp[4]*value)/temp[3]) + 1);
q[0] += q[1];
// Now multiply q[1] with d[1, 0] and subtract that from n[1, 0] and put it in n[1, 0].. (temp[1]temp[0] X temp[3]temp[2]) = temp[7]temp[6]temp[5]temp[4]
temp[3] = q[1]>>64; temp[2] = q[1]&0xffffffffffffffff;
temp_1 = temp[0]*temp[2]; temp[4] = temp_1&0xffffffffffffffff; temp[5] = temp_1>>64; temp_1 = temp[1]*temp[2]; temp[5] += temp_1&0xffffffffffffffff; temp[6] = temp_1>>64;
temp_1 = temp[0]*temp[3]; temp[5] += temp_1&0xffffffffffffffff; temp[6] += temp_1>>64; temp_1 = temp[1]*temp[3]; temp[6] += temp_1&0xffffffffffffffff; temp[7] = temp_1>>64;
temp[6] += temp[5]>>64; temp[7] = (temp[7]<<64) + temp[6];
temp[3] = q[1]*d[1] + temp[7]; temp[2] = (temp[5]<<64) + temp[4];// This must be less than n[1, 0].. helper makes sure that this value is <= n[1, 0]. get helper right!!!
n[1] = n[1] - temp[3] - (temp[2] > n[0]); n[0] -= temp[2];
printf("n[1] - %lx ",n[1]>>64); printf("%lx\n",n[1]);
printf("q[1] - %lx ",q[1]>>64); printf("%lx\n",q[1]);
}
q[1] = n[1]/d[1];
printf("YEP 2\n"); printf("q[1] - %lx ",q[1]>>64); printf("%lx\n",q[1]);
}
// n[1,0] contains the remainder.. Swap n & d.. No need for temp[0] & temp[1] now.. So use them!
printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
temp[1] = n[1]; temp[0] = n[0]; n[1] = d[1]; n[0] = d[0]; d[1] = temp[1]; d[0] = temp[0];
printf("YEP 3\n"); printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
printf("d[1] - %lx ",d[1]>>64); printf("%lx ",d[1]); printf("d[0] - %lx ",d[0]>>64); printf("%lx\n",d[0]);
// Calculate m1 = m - q[0]*m1;
// m1*q[0]
temp[5] = m1[1]>>64; temp[4] = m1[1]&0xffffffffffffffff; temp[3] = m1[0]>>64; temp[2] = m1[0]&0xffffffffffffffff;
temp[1] = q[0]>>64; temp[0] = q[0]&0xffffffffffffffff; // temp[5]temp[4]temp[3]temp[2]Xtemp[1]temp[0] = temp[7]temp[6]q[1]q[0]// Using this cuz its only free.. No need for q[1, 0] now!
temp_1 = temp[0]*temp[2]; q[0] = temp_1&0xffffffffffffffff; q[1] = temp_1>>64; temp_1 = temp[0]*temp[3]; q[1] += temp_1&0xffffffffffffffff; temp[6] = temp_1>>64;
temp_1 = temp[0]*temp[4]; temp[6] += temp_1&0xffffffffffffffff; temp[7] = temp_1>>64; temp_1 = temp[0]*temp[5]; temp[7] += temp_1&0xffffffffffffffff;
temp_1 = temp[1]*temp[2]; q[1] += temp_1&0xffffffffffffffff; temp[6] += temp_1>>64; temp_1 = temp[1]*temp[3]; temp[6] += temp_1&0xffffffffffffffff; temp[7] += temp_1>>64;
temp_1 = temp[1]*temp[4]; temp[7] += temp_1&0xffffffffffffffff;
temp[6] += q[1]>>64; temp[7] = (temp[7]<<64) + temp[6]; temp[6] = (q[1]<<64) + q[0];
sign_2 = !sign_2;
if(sign_1^sign_2)
{
temp[6] += m[0]; temp[7] += m[1] + (temp[6] < m[0]); m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6];
}
else
{
// Use value to see which is greater.. if abs(m) > abs(temp), 1 else 0
value = (m[1] > temp[7])||((m[1] == temp[7]) && m[0] > temp[6]);
if(value)
{
// m[1, 0] - temp[7, 6]
temp[7] = m[1] - temp[7] - (temp[6] > m[0]); temp[6] = m[0] - temp[6]; m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; sign_1 = !sign_1; sign_2 = !sign_2;
}
else
{
// temp[7, 6] - m[1, 0]
temp[7] = temp[7] - m[1] - (m[0] > temp[6]); temp[6] -= m[0]; m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; sign_1 = sign_2;
}
}
}
printf("YEP 4\n");
if(n[1] != 0)
{
printf("YEP 6\n");
printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
q[1] = n[1]/d[0];
n[1] -= d[0]*q[1];
q[0] = 0;
printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
printf("d[1] - %lx ",d[1]>>64); printf("%lx ",d[1]); printf("d[0] - %lx ",d[0]>>64); printf("%lx\n",d[0]);
temp[1] = d[0]>>64; temp[0] = d[0]&0xffffffffffffffff;
while(n[1])
{
printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
//printf("YEP 5\n");
q[0] += n[1];
// d[0] X n[1], then subtract from n[1]n[0]; temp[1]temp[0] X temp[3]temp[2] = temp[7]temp[6]temp[5]temp[4] = temp[7]temp[6]
temp[3] = n[1]>>64; temp[2] = n[1]&0xffffffffffffffff;
//printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
//printf("temp[1] - %lx ",temp[1]); printf("temp[0] - %lx ",temp[0]); printf("temp[3] - %lx ",temp[3]); printf("temp[2] - %lx\n",temp[2]);
temp_1 = temp[0]*temp[2]; temp[4] = temp_1&0xffffffffffffffff; temp[5] = temp_1>>64; temp_1 = temp[1]*temp[2]; temp[5] += temp_1&0xffffffffffffffff; temp[6] = temp_1>>64;
temp_1 = temp[0]*temp[3]; temp[5] += temp_1&0xffffffffffffffff; temp[6] += temp_1>>64; temp_1 = temp[1]*temp[3]; temp[6] += temp_1&0xffffffffffffffff; temp[7] = temp_1>>64;
temp[6] += temp[5]>>64; temp[7] = (temp[7]<<64) + temp[6]; temp[6] = temp[5]<<64 + temp[4];
n[1] = n[1] - temp[7] - (temp[6] > n[0]); n[0] -= temp[6];
//printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
//printf("temp[7] - %lx ",temp[7]>>64); printf("%lx ",temp[7]); printf("temp[6] - %lx ",temp[6]>>64); printf("%lx\n",temp[6]);
}
printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
n[0] = d[1]; n[0] = d[0]; d[0] = d[1]; d[1] = 0;
// Now m1 = m - q[1, 0]*m1; This is 256*256 bit multiplication = 256 bits
temp[7] = q[1]>>64; temp[6] = q[1]&0xffffffffffffffff; temp[5] = q[0]>>64; temp[4] = q[0]&0xffffffffffffffff;
// temp[1]temp[0]q[1]q[0] X temp[7]temp[6]temp[5]temp[4] = temp[3]temp[2]temp[1]temp[0] = temp[7]temp[6]
temp[1] = m1[1]>>64; temp[0] = m1[1]&0xffffffffffffffff; q[1] = m1[0]>>64; q[0] = m1[0]&0xffffffffffffffff;
temp_1 = q[0]*temp[4]; temp[0] = temp_1&0xffffffffffffffff; temp[1] = temp_1>>64; temp_1 = q[1]*temp[4]; temp[1] += temp_1&0xffffffffffffffff; temp[2] = temp_1>>64;
temp_1 = temp[0]*temp[4]; temp[2] += temp_1&0xffffffffffffffff; temp[3] = temp_1>>64; temp_1 = temp[1]*temp[4]; temp[3] += temp_1&0xffffffffffffffff;
temp_1 = q[0]*temp[5]; temp[1] += temp_1&0xffffffffffffffff; temp[2] += temp_1>>64; temp_1 = q[1]*temp[5]; temp[2] += temp_1&0xffffffffffffffff; temp[3] += temp_1>>64;
temp_1 = temp[0]*temp[5]; temp[3] += temp_1&0xffffffffffffffff;
temp_1 = q[0]*temp[6]; temp[2] += temp_1&0xffffffffffffffff; temp[3] += temp_1>>64; temp_1 = q[1]*temp[6]; temp[3] += temp_1&0xffffffffffffffff;
temp_1 = q[0]*temp[7]; temp[3] += temp_1&0xffffffffffffffff;
temp[2] += temp[1]>>64; temp[7] = (temp[3]<<64) + temp[2]; temp[6] = (temp[1]<<64) + temp[0];
sign_2 = !sign_2;
if(sign_1^sign_2)
{
temp[6] += m[0]; temp[7] += m[1] + (temp[6] < m[0]); m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6];
}
else
{
// Use value to see which is greater.. if abs(m) > abs(temp), 1 else 0
value = (m[1] > temp[7])||((m[1] == temp[7]) && m[0] > temp[6]);
if(value)
{
// m[1, 0] - temp[7, 6]
temp[7] = m[1] - temp[7] - (temp[6] > m[0]); temp[6] = m[0] - temp[6]; m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; sign_1 = !sign_1; sign_2 = !sign_2;
}
else
{
// temp[7, 6] - m[1, 0]
temp[7] = temp[7] - m[1] - (m[0] > temp[6]); temp[6] -= m[0]; m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; sign_1 = sign_2;
}
}
printf("YEP 5\n");
}
while(d[0] != 1)
{
q[0] = n[0]/d[0];
temp_1 = n[0]%d[0];
n[0] = d[0]; d[0] = temp_1;
printf("YEP 6\n");
// Calculate m1 = m - q[0]*m1;
// m1*q[0]
temp[5] = m1[1]>>64; temp[4] = m1[1]&0xffffffffffffffff; temp[3] = m1[0]>>64; temp[2] = m1[0]&0xffffffffffffffff;
temp[1] = q[0]>>64; temp[0] = q[0]&0xffffffffffffffff; // temp[5]temp[4]temp[3]temp[2]Xtemp[1]temp[0] = temp[7]temp[6]q[1]q[0]// Using this cuz its only free.. No need for q[1, 0] now!
temp_1 = temp[0]*temp[2]; q[0] = temp_1&0xffffffffffffffff; q[1] = temp_1>>64; temp_1 = temp[0]*temp[3]; q[1] += temp_1&0xffffffffffffffff; temp[6] = temp_1>>64;
temp_1 = temp[0]*temp[4]; temp[6] += temp_1&0xffffffffffffffff; temp[7] = temp_1>>64; temp_1 = temp[0]*temp[5]; temp[7] += temp_1&0xffffffffffffffff;
temp_1 = temp[1]*temp[2]; q[1] += temp_1&0xffffffffffffffff; temp[6] += temp_1>>64; temp_1 = temp[1]*temp[3]; temp[6] += temp_1&0xffffffffffffffff; temp[7] += temp_1>>64;
temp_1 = temp[1]*temp[4]; temp[7] += temp_1&0xffffffffffffffff;
temp[6] += q[1]>>64; temp[7] = (temp[7]<<64) + temp[6]; temp[6] = (q[1]<<64) + q[0];
sign_2 = !sign_2;
if(sign_1^sign_2)
{
temp[6] += m[0]; temp[7] += m[1] + (temp[6] < m[0]); m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6];
}
else
{
// Use value to see which is greater.. if abs(m) > abs(temp), 1 else 0
value = (m[1] > temp[7])||((m[1] == temp[7]) && m[0] > temp[6]);
if(value)
{
// m[1, 0] - temp[7, 6]
temp[7] = m[1] - temp[7] - (temp[6] > m[0]); temp[6] = m[0] - temp[6]; m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; sign_1 = !sign_1; sign_2 = !sign_2;
}
else
{
// temp[7, 6] - m[1, 0]
temp[7] = temp[7] - m[1] - (m[0] > temp[6]); temp[6] -= m[0]; m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; sign_1 = sign_2;
}
}
}
if(!sign_2)
{
// add p to m[1, 0]
m[0] += UINT128(0xffffffffffffffff, 0xfffffffefffffc2f); m[1] += UINT128(0xffffffffffffffff, 0xffffffffffffffff) + (m[0] < UINT128(0xffffffffffffffff, 0xfffffffefffffc2f));
}
printf("%lx\n", m[1]>>64);
printf("%lx\n", m[1]&0xffffffffffffffff);
printf("%lx\n", m[0]>>64);
printf("%lx\n", m[0]&0xffffffffffffffff);
return 0;
}
Upvotes: 0
Views: 56