user2649681

Schonhage-Strassen Multiplication implementation error

I am attempting to implement the Schonhage-Strassen multiplication algorithm using NTT, and am running into a problem where the final resulting vector is not actually equal to what it should be.

For two input vectors a and b, each consisting of N "digits" of K bits (with the final N/2 entries of each set to 0), and given a modulus M = 2^(2*K)+1, a root of unity w = N^(4*K-1) | w^N = 1 mod M, a modular inverse of this value wi | wi*w = 1 mod M, and u | u*N = 1 mod M, the following Python code is used to (attempt to) multiply these vectors using the Schonhage-Strassen algorithm:

#a and b are lists of length N, representing large integers
A = [ sum([ (a[i]*pow(w,i*j,M))%M for i in range(N)]) for j in range(N)] #NTT of a
B = [ sum([ (b[i]*pow(w,i*j,M))%M for i in range(N)]) for j in range(N)] #NTT of b
C = [ (A[i]*B[i])%M for i in range(N)] #A * B multiplied pointwise
c = [ sum([ (C[i]*pow(wi,i*j,M))%M for i in range(N)]) for j in range(N)] #intermediate step in INTT of C
ci = [ (i*u)%M for i in c] #INTT of C, should be product of a and b

In theory, taking the NTT of a and b, multiplying pointwise, then taking the INTT of the result should give the product, if I am not mistaken, and I have tested these methods for NTT and INTT to confirm that they are inverses of each other. However, the final resulting vector ci, rather than being equal to the product of a and b, is the product where each element is taken modulo M, giving an incorrect result for the product.

For example, running a test with N=K=8, and random vectors for a, b, gives the following:

M = 2^(2*8)+1 = 65537
w = 16, wi = 61441
u = 57345
a = [212, 251, 84, 186, 0, 0, 0, 0] (3126131668 as an integer)
b = [180, 27, 234, 225, 0, 0, 0, 0] (3790216116)
NTT(a) = [733, 66681, 147842, 92262, 130933, 107825, 114562, 127302]
NTT(b) = [666, 64598, 80332, 54468, 131236, 186644, 181708, 88232]
Pointwise product of above two lines mod M = [29419, 39913, 25015, 14993, 42695, 49488, 52438, 51319]
INTT of above line (i.e. result) = [38160, 50904, 5968, 11108, 15616, 62424, 41850, 0] (11848430946168040720)
Actual product of a x b = [38160, 50904, 71505, 142182, 81153, 62424, 41850, 0] (11848714628791561488)

In this example, and in pretty much every time I try it, the elements of the actual product and the result of my algorithm are the same for several elements near the beginning and end of the vector, but towards the middle they deviate. As I mentioned above, the elements of ci are each equal to the elements of a*b modulo M. I must be misunderstanding something about this algorithm, though I'm not entirely sure what. Am I using the wrong modulus somewhere?
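
The observation that the result equals the true product reduced modulo M can be checked directly against the numbers in the example. The short sketch below only replays the values listed above; the variable names `result` and `actual` are introduced here for illustration and are not part of the original code:

```python
M = 2**(2 * 8) + 1  # 65537, as in the example

# Values copied from the example output above
result = [38160, 50904, 5968, 11108, 15616, 62424, 41850, 0]
actual = [38160, 50904, 71505, 142182, 81153, 62424, 41850, 0]

# Reducing every element of the true product modulo M
# reproduces the vector returned by the algorithm.
reduced = [x % M for x in actual]
print(reduced == result)  # True

# Positions where the true coefficient does not fit below M
too_big = [i for i, x in enumerate(actual) if x >= M]
print(too_big)  # [2, 3, 4]
```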

Answers (1)

Spektre

Beware: number theory and NTT are not my field of expertise, so read this critically. I did, however, successfully implement NTT in C++ on my own and used it for bignum multiplication (bigint, big floating point, big fixed point), so here are some insights of mine. I strongly suggest you read both of my related QAs first:

so you can compare your results, code, and constants with mine. Note, however, that I have since evolved my NTT to use a single hardcoded prime (the biggest root of unity that fits in a 32-bit value).

Now, what can be wrong with your code? I do not code in Python, and I do not see NTT code in your question. Anyway, from what I can see:

  1. Check your root of unity

    In your question you mention condition:

    w^N = 1 mod M
    

    but that is far from enough. See the first link above: it describes all the conditions that must be met (with code that finds and checks them). I am not sure whether your parameters satisfy all the needed conditions and you just forgot to write them down, so check that. IIRC I struggled with those conditions too, because at the time I coded this there was very little NTT info at my disposal, and most of it was incomplete or wrong...

  2. You are not using modular arithmetic!

    I assume your prime is M (in my terminology it is p), so all the sub-results must be smaller than M, which is clearly not true in your example:

    M = 65537
    NTT(a) = [733, 66681, 147842, 92262, 130933, 107825, 114562, 127302]
    NTT(b) = [666, 64598, 80332, 54468, 131236, 186644, 181708, 88232]
    

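    As you can see, most elements of both NTTs are bigger than M (only 733 in NTT(a) and 666, 64598 and 54468 in NTT(b) are in range). That is wrong: the sums inside the transform must be reduced modulo M as well.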

  3. Beware of overflows

    Your M is really small (~16 bit) in comparison to your input values, which look to be ~8 bit. That can overflow really fast, invalidating your NTT results too.

    Here is a quote from my second link, something I found out the hard, empirical way:

    To avoid overflows for big datasets, limit input numbers to p/4 bits. Where p is number of bits per NTT element so for this 32 bit version use max (32 bit/4 -> 8 bit) input values.

    So in your case you should process 16/4 = 4-bit chunks instead of 8-bit ones, or use a much bigger M, for example mine, 0xC0000001, which is ~32 bit.

    This explains your observation that the first elements of the product are good and the later ones are not. If you multiply two 8-bit numbers you get a 16-bit result, and each product element is a sum of several such sub-results, so it gets above your 16-bit M very soon: in your case already at the third value (71505 > 65537).
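
The checks from the three points above can be sketched in a few lines of Python. This is only an illustration under assumptions: the conditions tested are the textbook ones for an invertible length-N transform (they are not copied from the linked answers), and the helper names `is_usable_root` and `largest_coefficient` are made up for this sketch.

```python
from math import gcd


def is_usable_root(w, n, m):
    """Check textbook conditions for an invertible length-n NTT modulo m."""
    if pow(w, n, m) != 1:   # w^n = 1 mod m (the condition from the question)
        return False
    if gcd(n, m) != 1:      # n needs a modular inverse for the inverse NTT
        return False
    # For every 0 < j < n the powers of w^j must sum to zero modulo m,
    # otherwise the inverse transform does not undo the forward one.
    for j in range(1, n):
        if sum(pow(w, i * j, m) for i in range(n)) % m != 0:
            return False
    return True


def largest_coefficient(n, bits):
    """Largest product element when only n/2 digits of `bits` bits are nonzero."""
    return (n // 2) * (2**bits - 1) ** 2


M = 65537
print(is_usable_root(16, 8, M))        # True: the w from the question
print(is_usable_root(M - 1, 8, M))     # False, although (M-1)^8 = 1 mod M
print(largest_coefficient(8, 8) < M)   # False: 8-bit chunks do not fit
print(largest_coefficient(8, 4) < M)   # True: 4-bit chunks do fit
```

With the question's values the root itself passes these particular checks, while the size check fails for 8-bit chunks and holds for 4-bit ones, which agrees with the 16/4 = 4 bit suggestion.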

So, in summary: you are not using modular arithmetic, you have too small a prime and/or process too big chunks of data, and you possibly also have the wrong prime selected.
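
A minimal Python sketch of the multiplication with these points applied might look as follows. It is not the C++ code from the linked answers: it uses a naive O(N^2) transform, takes M = 0xC0000001 from above as the modulus, and finds a root of unity by a simple search. The helpers `find_root`, `ntt`, `multiply_digits` and `to_int` are hypothetical names introduced here, and `find_root` assumes that N is a power of two dividing M - 1.

```python
def find_root(n, m):
    """Search for w with w^(n/2) = -1 mod m (n a power of two dividing m - 1)."""
    for g in range(2, m):
        w = pow(g, (m - 1) // n, m)
        if pow(w, n // 2, m) == m - 1:
            return w
    raise ValueError("no suitable root found")


def ntt(values, root, m):
    """Naive transform; note that every sum is reduced modulo m."""
    n = len(values)
    return [
        sum(values[i] * pow(root, i * j, m) for i in range(n)) % m
        for j in range(n)
    ]


def multiply_digits(a, b, m):
    """Multiply two zero-padded digit vectors of equal length via NTT."""
    n = len(a)
    w = find_root(n, m)
    w_inv = pow(w, -1, m)   # modular inverse, needs Python 3.8+
    n_inv = pow(n, -1, m)
    fa = ntt(a, w, m)
    fb = ntt(b, w, m)
    fc = [(x * y) % m for x, y in zip(fa, fb)]
    return [(x * n_inv) % m for x in ntt(fc, w_inv, m)]


def to_int(digits, bits):
    """Interpret a little-endian digit vector as an integer (carries included)."""
    return sum(d << (bits * i) for i, d in enumerate(digits))


M = 0xC0000001  # ~32-bit modulus, large enough for 8-bit chunks with N = 8
a = [212, 251, 84, 186, 0, 0, 0, 0]
b = [180, 27, 234, 225, 0, 0, 0, 0]
c = multiply_digits(a, b, M)
print(c)  # [38160, 50904, 71505, 142182, 81153, 62424, 41850, 0]
print(to_int(c, 8) == to_int(a, 8) * to_int(b, 8))  # True
```

The product elements are still larger than one 8-bit digit, so the carries have to be propagated afterwards; `to_int` does that implicitly by shifting and adding.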
