Knm
Knm

Reputation: 69

Code Segment changing value of variable without any human interference

Consider the code segment below
Explanation - It performs division of two numbers n & d where size(d) <= 128 bits and 128 bits <= size(n) <= 256 bits. The data types are __uint128_t n[2], d[2], temp[8], q[2], temp_1. The values of n & d are obtained from the previous section of the code. I have printed the values for reference.

printf("YEP 6\n");
printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
q[1] = n[1]/d[0];
n[1] -= d[0]*q[1];
q[0] = 0;
printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
printf("d[1] - %lx ",d[1]>>64); printf("%lx ",d[1]); printf("d[0] - %lx ",d[0]>>64); printf("%lx\n",d[0]);
temp[1] = d[0]>>64; temp[0] = d[0]&0xffffffffffffffff;

while(n[1])
{ 
    printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
    q[0] += n[1];
    temp[3] = n[1]>>64; temp[2] = n[1]&0xffffffffffffffff;
    temp_1 = temp[0]*temp[2]; temp[4] = temp_1&0xffffffffffffffff; temp[5] = temp_1>>64; temp_1 = temp[1]*temp[2]; temp[5] += temp_1&0xffffffffffffffff; temp[6] = temp_1>>64; 
    temp_1 = temp[0]*temp[3]; temp[5] += temp_1&0xffffffffffffffff; temp[6] += temp_1>>64; temp_1 = temp[1]*temp[3]; temp[6] += temp_1&0xffffffffffffffff; temp[7] = temp_1>>64; 
    temp[6] += temp[5]>>64; temp[7] = (temp[7]<<64) + temp[6]; temp[6] = temp[5]<<64 + temp[4];
    n[1] = n[1] - temp[7] - (temp[6] > n[0]); n[0] -= temp[6];
}

The output of this segment is

YEP 6
n[1] - 0 ffffffffffffffff n[0] - ffffffffffffffff fffffffffffffff
n[1] - 0 ffffffffffffffff n[0] - ffffffffffffffff fffffffffffffff
d[1] - 0 0 d[0] - f000000000000000 fffffffefffffc2f
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff
n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff

The output goes on to infinity with n[1] - 0 1 n[0] - feeeeeeeeffda020 ff320000821ffff repeating. Notice that the value of n[1] & n[0] has changed when it enters the loop even though I have not made any changes (I have performed subtraction with n[1] above which is why i printed the values before and after to denote no change). I have printed n[1] & n[0] before the start of the while loop and after the while loop to verify that.
This is causing no changes to n[1] overall because n[1] has been changed to 1 instead of 0xffffffffffffffff and thus n[1] will not be changed by remaining code which is causing the infinite loop.
Why is this happening? Any memory corruption of sorts? Or anything with while loops i seem to have missed? I even changed to while(n[1] > 0) and no changes.
I am compiling using gcc/g++.

gcc (Ubuntu 11.3.0-1ubuntu1~22.04.1) 11.3.0
g++ (Ubuntu 11.3.0-1ubuntu1~22.04.1) 11.3.0

The Complete program
Explanation - This program is aimed to calculate the multiplicative inverse of a given number in a prime field only in C to be used in CUDA. So, mpz_invert() in GMP library or other libraries cannot be used. The bit shifting technique normally used for division of very large numbers is very slow as i have to perform trillions of such calculations. This is still to be tested whether it gives the correct answer or not. The issue was found when i was verifying. It compiles though..

#include <stdio.h>
#define UINT128(hi, lo) (((__uint128_t) (hi)) << 64 | (lo))
int main()
{
    __uint128_t n[2], d[2], q[2], temp[8], temp_1, m[2], m1[2], k1[2], w[2];
    __uint64_t value;
    double helper;
    int sign_1, sign_2, i; // 1 is +ve

    n[1] = UINT128(0xffffffffffffffff, 0xffffffffffffffff); n[0] = UINT128(0xffffffffffffffff, 0xfffffffefffffc2f); 
    k1[1] = UINT128(0x0, 0xffffffffffffffff); k1[0] = UINT128(0xffffffffffffffff, 0x0fffffffffffffff);


    d[1] = k1[1]; d[0] = k1[0]; m[1] = 0; m[0] = 0; m1[1] = 0; m1[0] = 1;// If k1 is not used here, you can shift d for k1; 
    sign_1 = 1; sign_2 = 1;

    // printf("d[1] - %lx \n", d[1]>>64); printf("d[1] - %lx \n", d[1]);

    while(d[1] != 0)
    {

        q[1] = n[1]/d[1]; q[0] = 0; printf("q[1] - %lx ",q[1]>>64); printf("%lx\n",q[1]); 
        temp[1] = d[0]>>64; temp[0] = d[0]&0xffffffffffffffff; // This is done to avoid multiple computations of d[1, 0]
        while(q[1])
        {
            printf("YEP 1\n");
            // Now, multiply q[1] with d[0] to get the overflow.. (temp[1]temp[0] X temp[3]temp[2]) = temp[7]temp[6]temp[5]temp[4]
            temp[3] = q[1]>>64; temp[2] = q[1]&0xffffffffffffffff;
            temp_1 = temp[0]*temp[2]; temp[4] = temp_1&0xffffffffffffffff; temp[5] = temp_1>>64; temp_1 = temp[1]*temp[2]; temp[5] += temp_1&0xffffffffffffffff; temp[6] = temp_1>>64; temp_1 = temp[0]*temp[3]; temp[5] += temp_1&0xffffffffffffffff; temp[6] += temp_1>>64; temp_1 = temp[1]*temp[3]; temp[6] += temp_1&0xffffffffffffffff; temp[7] = temp_1>>64; 
            temp[6] += temp[5]>>64; temp[7] = (temp[7]<<64) + temp[6];
            if(d[1] >= temp[7])
            {
                printf("1 - %lx ",n[1]>>64); printf("%lx\n",n[1]); printf("temp[7] - %lx %lx\n",temp[7]>>64, temp[7]);
                // Multiply with q[1] and check and subtract 1 if needed!!
                temp[3] = q[1]*d[1] + temp[7]; temp[2] = (temp[5]<<64) + temp[4];// Wont temp[3] overflow??? It wont!!
                q[0] += q[1];
                if((temp[3] > n[1]) || ((temp[3] == n[1]) && (temp[2] > n[0])))
                {
                    q[0]--;
                    // Subtract d[1, 0] from temp[3, 2] and put the value in temp[3, 2]
                    temp[3] = temp[3] - d[1] - (d[0] > temp[2]); temp[2] -= d[0];
                }
                // Subtract temp[3, 2] from n[1, 0] and put the remainder in n[1, 0]

                n[1] = n[1] - temp[3] - (temp[2] > n[0]); n[0] -= temp[2];
                printf("%lx ",n[1]); printf("%lx\n",n[0]);
            }
            else
            {
                printf("2 - n[1] %lx ",n[1]>>64); printf("%lx\n",n[1]); printf("temp[7] - %lx %lx\n",temp[7]>>64, temp[7]);
                helper = (d[0]>>96);
                helper = (0xffffffff/helper)*10;// Also check if 32 bit approx is enough to justify 128 bit behaviour
                value = helper; printf("value - %lx\n",value);
                temp[3] = d[1]*value + 10; temp[2] = temp[7]/temp[3]; temp[4] = temp[7]%temp[3]; printf("temp[3] - %lx ",(temp[2]*value + ((temp[4]*value)/temp[3]) + 1)>>64); printf("%lx\n",(temp[2]*value + ((temp[4]*value)/temp[3]) + 1));
                q[1] -= (temp[2]*value + ((temp[4]*value)/temp[3]) + 1);
                q[0] += q[1];
                // Now multiply q[1] with d[1, 0] and subtract that from n[1, 0] and put it in n[1, 0]..  (temp[1]temp[0] X temp[3]temp[2]) = temp[7]temp[6]temp[5]temp[4]
                temp[3] = q[1]>>64; temp[2] = q[1]&0xffffffffffffffff;
                temp_1 = temp[0]*temp[2]; temp[4] = temp_1&0xffffffffffffffff; temp[5] = temp_1>>64; temp_1 = temp[1]*temp[2]; temp[5] += temp_1&0xffffffffffffffff; temp[6] = temp_1>>64; 
                temp_1 = temp[0]*temp[3]; temp[5] += temp_1&0xffffffffffffffff; temp[6] += temp_1>>64; temp_1 = temp[1]*temp[3]; temp[6] += temp_1&0xffffffffffffffff; temp[7] = temp_1>>64; 
                temp[6] += temp[5]>>64; temp[7] = (temp[7]<<64) + temp[6];
                temp[3] = q[1]*d[1] + temp[7]; temp[2] = (temp[5]<<64) + temp[4];// This must be less than n[1, 0].. helper makes sure that this value is <= n[1, 0]. get helper right!!!
                n[1] = n[1] - temp[3] - (temp[2] > n[0]); n[0] -= temp[2];
                printf("n[1] - %lx ",n[1]>>64); printf("%lx\n",n[1]);
                printf("q[1] - %lx ",q[1]>>64); printf("%lx\n",q[1]);
            }
            q[1] = n[1]/d[1]; 
            printf("YEP 2\n"); printf("q[1] - %lx ",q[1]>>64); printf("%lx\n",q[1]);

        }

        // n[1,0] contains the remainder.. Swap n & d.. No need for temp[0] & temp[1] now.. So use them!
        printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
        temp[1] = n[1]; temp[0] = n[0]; n[1] = d[1]; n[0] = d[0]; d[1] = temp[1]; d[0] = temp[0];
        printf("YEP 3\n"); printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
        printf("d[1] - %lx ",d[1]>>64); printf("%lx ",d[1]); printf("d[0] - %lx ",d[0]>>64); printf("%lx\n",d[0]);
        // Calculate m1 = m - q[0]*m1;
        // m1*q[0]
        temp[5] = m1[1]>>64; temp[4] = m1[1]&0xffffffffffffffff; temp[3] = m1[0]>>64; temp[2] = m1[0]&0xffffffffffffffff; 
        temp[1] = q[0]>>64; temp[0] = q[0]&0xffffffffffffffff; // temp[5]temp[4]temp[3]temp[2]Xtemp[1]temp[0] = temp[7]temp[6]q[1]q[0]// Using this cuz its only free.. No need for q[1, 0] now!
        temp_1 = temp[0]*temp[2]; q[0] = temp_1&0xffffffffffffffff; q[1] = temp_1>>64; temp_1 = temp[0]*temp[3]; q[1] += temp_1&0xffffffffffffffff; temp[6] = temp_1>>64;
        temp_1 = temp[0]*temp[4]; temp[6] += temp_1&0xffffffffffffffff; temp[7] = temp_1>>64; temp_1 = temp[0]*temp[5]; temp[7] += temp_1&0xffffffffffffffff;
        temp_1 = temp[1]*temp[2]; q[1] += temp_1&0xffffffffffffffff; temp[6] += temp_1>>64; temp_1 = temp[1]*temp[3]; temp[6] += temp_1&0xffffffffffffffff; temp[7] += temp_1>>64;
        temp_1 = temp[1]*temp[4]; temp[7] += temp_1&0xffffffffffffffff;
        temp[6] += q[1]>>64; temp[7] = (temp[7]<<64) + temp[6]; temp[6] = (q[1]<<64) + q[0];

        sign_2 = !sign_2;
        if(sign_1^sign_2)
        {
            temp[6] += m[0]; temp[7] += m[1] + (temp[6] < m[0]); m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; 
        }
        else
        {
            // Use value to see which is greater.. if abs(m) > abs(temp), 1 else 0
            value = (m[1] > temp[7])||((m[1] == temp[7]) && m[0] > temp[6]);
            if(value)
            {
                // m[1, 0] - temp[7, 6]
                temp[7] = m[1] - temp[7] - (temp[6] > m[0]); temp[6] = m[0] - temp[6]; m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; sign_1 = !sign_1; sign_2 = !sign_2;
            }
            else
            {
                // temp[7, 6] - m[1, 0]
                temp[7] = temp[7] - m[1] - (m[0] > temp[6]); temp[6] -= m[0]; m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; sign_1 = sign_2;
            }
        }
    }




    printf("YEP 4\n");

    if(n[1] != 0)
    {
        printf("YEP 6\n");
        printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
        q[1] = n[1]/d[0];
        n[1] -= d[0]*q[1];
        q[0] = 0;
        printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
        printf("d[1] - %lx ",d[1]>>64); printf("%lx ",d[1]); printf("d[0] - %lx ",d[0]>>64); printf("%lx\n",d[0]);
        temp[1] = d[0]>>64; temp[0] = d[0]&0xffffffffffffffff;

        while(n[1])
        { 
            printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
            //printf("YEP 5\n");
            q[0] += n[1];
            // d[0] X n[1], then subtract from n[1]n[0]; temp[1]temp[0] X temp[3]temp[2] = temp[7]temp[6]temp[5]temp[4] = temp[7]temp[6]
            temp[3] = n[1]>>64; temp[2] = n[1]&0xffffffffffffffff;
            //printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
            //printf("temp[1] - %lx ",temp[1]); printf("temp[0] - %lx ",temp[0]); printf("temp[3] - %lx ",temp[3]); printf("temp[2] - %lx\n",temp[2]);
            temp_1 = temp[0]*temp[2]; temp[4] = temp_1&0xffffffffffffffff; temp[5] = temp_1>>64; temp_1 = temp[1]*temp[2]; temp[5] += temp_1&0xffffffffffffffff; temp[6] = temp_1>>64; 
            temp_1 = temp[0]*temp[3]; temp[5] += temp_1&0xffffffffffffffff; temp[6] += temp_1>>64; temp_1 = temp[1]*temp[3]; temp[6] += temp_1&0xffffffffffffffff; temp[7] = temp_1>>64; 

            temp[6] += temp[5]>>64; temp[7] = (temp[7]<<64) + temp[6]; temp[6] = temp[5]<<64 + temp[4];
            
            n[1] = n[1] - temp[7] - (temp[6] > n[0]); n[0] -= temp[6];
            //printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);
            //printf("temp[7] - %lx ",temp[7]>>64); printf("%lx ",temp[7]); printf("temp[6] - %lx ",temp[6]>>64); printf("%lx\n",temp[6]);
        }

        printf("n[1] - %lx ",n[1]>>64); printf("%lx ",n[1]); printf("n[0] - %lx ",n[0]>>64); printf("%lx\n",n[0]);

        n[0] = d[1]; n[0] = d[0]; d[0] = d[1]; d[1] = 0;
        // Now m1 = m - q[1, 0]*m1; This is 256*256 bit multiplication = 256 bits
        temp[7] = q[1]>>64; temp[6] = q[1]&0xffffffffffffffff; temp[5] = q[0]>>64; temp[4] = q[0]&0xffffffffffffffff; 
        // temp[1]temp[0]q[1]q[0] X temp[7]temp[6]temp[5]temp[4] = temp[3]temp[2]temp[1]temp[0] = temp[7]temp[6]
        temp[1] = m1[1]>>64; temp[0] = m1[1]&0xffffffffffffffff; q[1] = m1[0]>>64; q[0] = m1[0]&0xffffffffffffffff; 

        temp_1 = q[0]*temp[4]; temp[0] = temp_1&0xffffffffffffffff; temp[1] = temp_1>>64; temp_1 = q[1]*temp[4]; temp[1] += temp_1&0xffffffffffffffff; temp[2] = temp_1>>64; 
        temp_1 = temp[0]*temp[4]; temp[2] += temp_1&0xffffffffffffffff; temp[3] = temp_1>>64; temp_1 = temp[1]*temp[4]; temp[3] += temp_1&0xffffffffffffffff; 
        temp_1 = q[0]*temp[5]; temp[1] += temp_1&0xffffffffffffffff; temp[2] += temp_1>>64; temp_1 = q[1]*temp[5]; temp[2] += temp_1&0xffffffffffffffff; temp[3] += temp_1>>64; 
        temp_1 = temp[0]*temp[5]; temp[3] += temp_1&0xffffffffffffffff;
        temp_1 = q[0]*temp[6]; temp[2] += temp_1&0xffffffffffffffff; temp[3] += temp_1>>64; temp_1 = q[1]*temp[6]; temp[3] += temp_1&0xffffffffffffffff;
        temp_1 = q[0]*temp[7]; temp[3] += temp_1&0xffffffffffffffff;
        temp[2] += temp[1]>>64; temp[7] = (temp[3]<<64) + temp[2]; temp[6] = (temp[1]<<64) + temp[0];

        sign_2 = !sign_2;
        if(sign_1^sign_2)
        {
            temp[6] += m[0]; temp[7] += m[1] + (temp[6] < m[0]); m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; 
        }
        else
        {
            // Use value to see which is greater.. if abs(m) > abs(temp), 1 else 0
            value = (m[1] > temp[7])||((m[1] == temp[7]) && m[0] > temp[6]);
            if(value)
            {
                // m[1, 0] - temp[7, 6]
                temp[7] = m[1] - temp[7] - (temp[6] > m[0]); temp[6] = m[0] - temp[6]; m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; sign_1 = !sign_1; sign_2 = !sign_2;
            }
            else
            {
                // temp[7, 6] - m[1, 0]
                temp[7] = temp[7] - m[1] - (m[0] > temp[6]); temp[6] -= m[0]; m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; sign_1 = sign_2;
            }
        }

    printf("YEP 5\n");
    }



    while(d[0] != 1)
    {
        q[0] = n[0]/d[0];
        temp_1 = n[0]%d[0];
        n[0] = d[0]; d[0] = temp_1;
            printf("YEP 6\n");
        // Calculate m1 = m - q[0]*m1;
        // m1*q[0]
        temp[5] = m1[1]>>64; temp[4] = m1[1]&0xffffffffffffffff; temp[3] = m1[0]>>64; temp[2] = m1[0]&0xffffffffffffffff; 
        temp[1] = q[0]>>64; temp[0] = q[0]&0xffffffffffffffff; // temp[5]temp[4]temp[3]temp[2]Xtemp[1]temp[0] = temp[7]temp[6]q[1]q[0]// Using this cuz its only free.. No need for q[1, 0] now!
        temp_1 = temp[0]*temp[2]; q[0] = temp_1&0xffffffffffffffff; q[1] = temp_1>>64; temp_1 = temp[0]*temp[3]; q[1] += temp_1&0xffffffffffffffff; temp[6] = temp_1>>64;
        temp_1 = temp[0]*temp[4]; temp[6] += temp_1&0xffffffffffffffff; temp[7] = temp_1>>64; temp_1 = temp[0]*temp[5]; temp[7] += temp_1&0xffffffffffffffff;
        temp_1 = temp[1]*temp[2]; q[1] += temp_1&0xffffffffffffffff; temp[6] += temp_1>>64; temp_1 = temp[1]*temp[3]; temp[6] += temp_1&0xffffffffffffffff; temp[7] += temp_1>>64;
        temp_1 = temp[1]*temp[4]; temp[7] += temp_1&0xffffffffffffffff;
        temp[6] += q[1]>>64; temp[7] = (temp[7]<<64) + temp[6]; temp[6] = (q[1]<<64) + q[0];

        sign_2 = !sign_2;
        if(sign_1^sign_2)
        {
            temp[6] += m[0]; temp[7] += m[1] + (temp[6] < m[0]); m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; 
        }
        else
        {
            // Use value to see which is greater.. if abs(m) > abs(temp), 1 else 0
            value = (m[1] > temp[7])||((m[1] == temp[7]) && m[0] > temp[6]);
            if(value)
            {
                // m[1, 0] - temp[7, 6]
                temp[7] = m[1] - temp[7] - (temp[6] > m[0]); temp[6] = m[0] - temp[6]; m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; sign_1 = !sign_1; sign_2 = !sign_2;
            }
            else
            {
                // temp[7, 6] - m[1, 0]
                temp[7] = temp[7] - m[1] - (m[0] > temp[6]); temp[6] -= m[0]; m[1] = m1[1]; m[0] = m1[0]; m1[1] = temp[7]; m1[0] = temp[6]; sign_1 = sign_2;
            }
        }
    }

    if(!sign_2)
    {
        // add p to m[1, 0]
        m[0] += UINT128(0xffffffffffffffff, 0xfffffffefffffc2f); m[1] += UINT128(0xffffffffffffffff, 0xffffffffffffffff) + (m[0] < UINT128(0xffffffffffffffff, 0xfffffffefffffc2f));
    }

    printf("%lx\n", m[1]>>64);
    printf("%lx\n", m[1]&0xffffffffffffffff);
    printf("%lx\n", m[0]>>64);
    printf("%lx\n", m[0]&0xffffffffffffffff);

    return 0;
}

Upvotes: 0

Views: 56

Answers (0)

Related Questions