sharp_c-tudent

KEM/DEM using the NTRU cryptosystem in Sage

First of all, I must say my knowledge of SageMath is very limited, but I really want to improve and be able to solve the problems I'm having. I have been asked to implement the following:

Use a Sage implementation of the NTRU cryptosystem and the Python library “cryptography” to build a KEM/DEM system with 128-bit security and a 128-bit generated key, using AES-128 as the cipher in the DEM phase.
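For orientation, the generic KEM/DEM (hybrid encryption) construction the task describes can be written as follows. Mapping Encap/Decap to `encapsulate`/`decapsulate` in the code below, taking the DEM to be AES-128, and letting $K$ be the KEM's shared secret (or a 128-bit key derived from it) is my reading of the task, not something it states; $\bot$ corresponds to `decapsulate` returning `False`.

```latex
\begin{aligned}
\mathsf{Enc}(pk, m) &:\; (c_0, K) \leftarrow \mathsf{Encap}(pk),\; c_1 \leftarrow \mathsf{DEM.Enc}_K(m),\; \text{output } (c_0, c_1) \\
\mathsf{Dec}(sk, (c_0, c_1)) &:\; K \leftarrow \mathsf{Decap}(sk, c_0);\; \text{if } K = \bot \text{ reject, else output } \mathsf{DEM.Dec}_K(c_1)
\end{aligned}
```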

While working on this I came across a Sage implementation of NTRU Prime and wanted to use it to solve the problem:

My Attempt:

p = 739; q = 9829; t = 204
Zx.<x> = ZZ[]; R.<xp> = Zx.quotient(x^p-x-1)
Fq = GF(q); Fqx.<xq> = Fq[]; Rq.<xqp> = Fqx.quotient(x^p-x-1)
F3 = GF(3); F3x.<x3> = F3[]; R3.<x3p> = F3x.quotient(x^p-x-1)

import itertools

def concat(lists):
    return list(itertools.chain.from_iterable(lists))

def nicelift(u):
    return lift(u + q//2) - q//2

def nicemod3(u): # r in {0,1,-1} with u-r in {...,-3,0,3,...}
    return u - 3*round(u/3)

def int2str(u,bytes):
    return ''.join([chr((u//256^i)%256) for i in range(bytes)])

def str2int(s):
    return sum([ord(s[i])*256^i for i in range(len(s))])

def encodeZx(m): # assumes coefficients in range {-1,0,1,2}
    m = [m[i]+1 for i in range(p)] + [0]*(-p % 4)
    return ''.join([int2str(m[i]+m[i+1]*4+m[i+2]*16+m[i+3]*64,1) for i in range(0,len(m),4)])

def decodeZx(mstr):
    m = [str2int(mstr[i:i+1]) for i in range(len(mstr))]
    m = concat([[m[i]%4,(m[i]//4)%4,(m[i]//16)%4,m[i]//64] for i in range(len(m))])
    return Zx([m[i]-1 for i in range(p)])

def encodeRq(h, nBits = 9856):
    h = [lift(h[i]) for i in range(p)] + [0]*(-p % 3)
    h = ''.join([int2str(h[i]+h[i+1]*10240+h[i+2]*10240^2,5) for i in range(0,len(h),3)])
    return h[0:nBits//8]  # slice bounds must be integers, so use floor division

def decodeRq(hstr):
    h = [str2int(hstr[i:i+5]) for i in range(0,len(hstr),5)]
    h = concat([[h[i]%10240,(h[i]//10240)%10240,h[i]//10240^2] for i in range(len(h))])
    if max(h) >= q: raise Exception("pk out of range")
    return Rq(h)

def encoderoundedRq(c,nBits):
    c = [1638 + nicelift(c[i]/3) for i in range(p)] + [0]*(-p % 2)
    c = ''.join([int2str(c[i]+c[i+1]*4096,3) for i in range(0,len(c),2)])
    return c[0:1109]

def decoderoundedRq(cstr):
    c = [str2int(cstr[i:i+3]) for i in range(0,len(cstr),3)]
    c = concat([[c[i]%4096,c[i]//4096] for i in range(len(c))])
    if max(c) > 3276: raise Exception("c out of range")
    return 3*Rq([c[i]-1638 for i in range(p)])

def randomR(): # R element with 2t coeffs +-1
    L = [2*randrange(2^31) for i in range(2*t)]
    L += [4*randrange(2^30)+1 for i in range(p-2*t)]
    L.sort()
    L = [(L[i]%4)-1 for i in range(p)]
    return Zx(L)

def keygen():
    while True:
        g = Zx([randrange(3)-1 for i in range(p)])
        if R3(g).is_unit(): break
    f = randomR()
    h = Rq(g)/(3*Rq(f))
    pk = encodeRq(h)
    return pk,encodeZx(f) + encodeZx(R(lift(1/R3(g)))) + pk


import hashlib

def hash(s):  # SHA-512 digest (64 bytes) of s; note this shadows Python's built-in hash()
    h = hashlib.sha512()
    h.update(s)
    return h.digest()

def encapsulate(pk):
    h = decodeRq(pk)
    r = randomR()
    hr = h * Rq(r)
    m = Zx([-nicemod3(nicelift(hr[i])) for i in range(p)])
    c = Rq(m) + hr
    fullkey = hash(encodeZx(r))
    return fullkey[:32] + encoderoundedRq(c,128),fullkey[32:]

def decapsulate(cstr,sk):
    f,ginv,h = decodeZx(sk[:185]),decodeZx(sk[185:370]),decodeRq(sk[370:])
    confirm,c = cstr[:32],decoderoundedRq(cstr[32:])
    f3mgr = Rq(3*f) * c
    f3mgr = [nicelift(f3mgr[i]) for i in range(p)]
    r = R3(ginv) * R3(f3mgr)
    r = Zx([nicemod3(lift(r[i])) for i in range(p)])
    hr = h * Rq(r)
    m = Zx([-nicemod3(nicelift(hr[i])) for i in range(p)])
    checkc = Rq(m) + hr
    fullkey = hash(encodeZx(r))
    if sum([r[i]==0 for i in range(p)]) != p-2*t: return False
    if checkc != c: return False
    if fullkey[:32] != confirm: return False
    return fullkey[32:]


print("Exe 2")
print("")
pk,sk =  keygen()
c,k   =  encapsulate(pk)
print(k == decapsulate(c,sk))  # True if decapsulation recovers the same shared secret
print("")
print("{:d} bytes in public key that is {:d} bits".format(len(pk),len(pk)*8))
print("{:d} bytes in secret key that is {:d} bits".format(len(sk),len(sk)*8))
print("{:d} bytes in ciphertext that is {:d} bits".format(len(c),len(c)*8))
print("{:d} bytes in shared secret that is {:d} bits".format(len(k),len(k)*8))
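Working through the encoders' arithmetic for p = 739 gives the sizes the four print lines should report. This is a sketch derived by reading the code above, not output I ran in Sage; `ceil_div` is a hypothetical helper. If the numbers match, the shared secret that `encapsulate`/`decapsulate` agree on is 256 bits and the secret key is 1602 bytes, which may help pin down which "128-bit key" the task means.

```python
# Byte sizes implied by the encoders above for p = 739 (worked by hand, not run in Sage).
P = 739


def ceil_div(a: int, b: int) -> int:
    """Integer ceiling of a / b."""
    return -(-a // b)


zx_bytes = ceil_div(2 * P, 8)        # encodeZx: 2 bits per coefficient -> 185 bytes
pk_bytes = 9856 // 8                 # encodeRq: 3 coeffs per 5 bytes, truncated to 9856 bits
sk_bytes = 2 * zx_bytes + pk_bytes   # f, then 1/g in R3, then a copy of pk
ct_bytes = 32 + ceil_div(12 * P, 8)  # confirmation hash + 12 bits per rounded coefficient (1109)
ss_bytes = 64 - 32                   # SHA-512 output minus the 32-byte confirmation

for name, n in [("public key", pk_bytes), ("secret key", sk_bytes),
                ("ciphertext", ct_bytes), ("shared secret", ss_bytes)]:
    print(f"{n} bytes in {name} that is {n * 8} bits")
# 1232 bytes in public key that is 9856 bits
# 1602 bytes in secret key that is 12816 bits
# 1141 bytes in ciphertext that is 9128 bits
# 32 bytes in shared secret that is 256 bits
```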

I know I can build on this to solve the task above. I assume the key mentioned in the task is the private key, because that is the one we generate (am I right?), so I think I have to edit this function:

def keygen():
    while True:
        g = Zx([randrange(3)-1 for i in range(p)])
        if R3(g).is_unit(): break
    f = randomR()
    h = Rq(g)/(3*Rq(f))
    pk = encodeRq(h) #Encode private key with 128 bits
    return pk,encodeZx(f) + encodeZx(R(lift(1/R3(g)))) + pk

To do this I tried editing the function encodeRq, which keygen calls, so that it encodes only 128 bits, but that produced lots of errors I couldn't understand. Am I at least right that this is where I have to make the generated key 128 bits?
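Note that `encodeRq` serializes the public polynomial h (it produces `pk`), packing three coefficients, each below q = 9829, into 5 bytes for 1232 bytes in total. Limiting it to 128 bits would therefore discard almost all of h, and `decodeRq`/`encapsulate` could no longer rebuild it. The sketch below is a Python 3 port of that packing on plain integer lists, using `bytes` rather than `str` (recent Sage versions run on Python 3, where `hashlib` rejects `str` input); the stand-in polynomial and the names `pack_rq`/`unpack_rq` are hypothetical.

```python
# Python 3 sketch of encodeRq/decodeRq's packing on plain integer lists (bytes, not str).
P, Q = 739, 9829   # parameters from the Sage code above
BASE = 10240       # radix used by encodeRq; 10240**3 < 2**40, so 3 coefficients fit in 5 bytes


def pack_rq(coeffs, n_bytes=1232):
    """Pack coefficients in [0, q) three at a time into 5 little-endian bytes, then truncate."""
    c = list(coeffs) + [0] * (-len(coeffs) % 3)
    out = bytearray()
    for i in range(0, len(c), 3):
        word = c[i] + c[i + 1] * BASE + c[i + 2] * BASE ** 2
        out += word.to_bytes(5, "little")
    return bytes(out[:n_bytes])  # same cut as encodeRq with nBits = 9856 (1232 bytes)


def unpack_rq(data, n=P):
    """Inverse of pack_rq (mirrors decodeRq); returns at most n coefficients."""
    coeffs = []
    for i in range(0, len(data), 5):
        word = int.from_bytes(data[i:i + 5], "little")
        coeffs += [word % BASE, (word // BASE) % BASE, word // BASE ** 2]
    return coeffs[:n]


h = [(7919 * i) % Q for i in range(P)]  # hypothetical stand-in for the lifted coefficients of h
full = pack_rq(h)
assert len(full) == 1232 and unpack_rq(full) == h  # the full encoding round-trips

short = pack_rq(h, n_bytes=16)   # what a "128-bit" encodeRq would keep
intact = 3 * (len(short) // 5)   # only complete 5-byte groups decode exactly
print(f"{len(short) * 8} bits keep {intact} of {P} coefficients")  # 128 bits keep 9 of 739 coefficients
```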

I believe the KEM part of the task is handled by the function encapsulate and that I don't have to change anything there (am I right?).
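In the code above, `encapsulate` hashes `encodeZx(r)` with SHA-512, puts the first 32 bytes into the ciphertext as a confirmation tag, and returns the last 32 bytes (256 bits) as the shared secret; `decapsulate` recomputes the same split. If the task's 128-bit key is meant to come from that secret, one option is a key-derivation function that outputs 16 bytes. The sketch below mirrors the split with `hashlib` and implements HKDF-SHA256 (RFC 5869) with the standard-library `hmac` module; the helper names and the `info` label are hypothetical, and the `cryptography` package also provides an HKDF class that could be used instead.

```python
import hashlib
import hmac


def split_fullkey(encoded_r: bytes):
    """Mirror of encapsulate()/decapsulate(): SHA-512 over the encoding of r,
    first 32 bytes = confirmation tag, last 32 bytes = shared secret."""
    fullkey = hashlib.sha512(encoded_r).digest()
    return fullkey[:32], fullkey[32:]


def hkdf_sha256(ikm: bytes, length: int, salt: bytes = b"", info: bytes = b"") -> bytes:
    """HKDF (RFC 5869) with SHA-256: extract a pseudorandom key, then expand it."""
    if length > 255 * hashlib.sha256().digest_size:
        raise ValueError("requested too many bytes for HKDF-SHA256")
    prk = hmac.new(salt or bytes(32), ikm, hashlib.sha256).digest()  # extract step
    okm, block, counter = b"", b"", 1
    while len(okm) < length:                                         # expand step
        block = hmac.new(prk, block + info + bytes([counter]), hashlib.sha256).digest()
        okm += block
        counter += 1
    return okm[:length]


def derive_aes128_key(shared_secret: bytes) -> bytes:
    """Hypothetical helper: a 16-byte AES-128 key from the 32-byte KEM shared secret."""
    return hkdf_sha256(shared_secret, 16, info=b"KEM/DEM AES-128 key")


confirm, secret = split_fullkey(b"stand-in for encodeZx(r)")
aes_key = derive_aes128_key(secret)
print(len(confirm), len(secret), len(aes_key))  # 32 32 16
```

Both sides run the same derivation on the same shared secret, so they end up with the same 16-byte AES key without changing `keygen` or `encodeRq`.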

The biggest problem is the DEM phase, which I believe is implemented in the function decapsulate (am I right?). How should I change it to use AES? How do I do that in Sage? Is there any library I should know about?
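In the usual KEM/DEM split, `decapsulate` is still part of the KEM (it is how the receiver recovers the shared secret); the DEM is a separate symmetric step that both sides run with that secret. Sage runs Python, so the `cryptography` package the task names can be imported once it is installed into Sage's Python (for example with `sage -pip install cryptography`). The sketch below assumes that package and Python 3 and uses AES-128 in GCM mode; the task only says "AES of 128 bits", so the choice of GCM (which also authenticates the ciphertext) and the helper names `dem_encrypt`/`dem_decrypt` are mine.

```python
import os

from cryptography.exceptions import InvalidTag
from cryptography.hazmat.primitives.ciphers.aead import AESGCM

NONCE_LEN = 12  # 96-bit nonce, the standard size for GCM


def dem_encrypt(key: bytes, plaintext: bytes, aad: bytes = b"") -> bytes:
    """DEM encryption with AES-128-GCM; returns nonce || ciphertext || 16-byte tag."""
    if len(key) != 16:
        raise ValueError("AES-128 needs a 16-byte key")
    nonce = os.urandom(NONCE_LEN)  # a fresh random nonce per message
    return nonce + AESGCM(key).encrypt(nonce, plaintext, aad)


def dem_decrypt(key: bytes, blob: bytes, aad: bytes = b"") -> bytes:
    """DEM decryption; raises ValueError if the blob or the associated data was modified."""
    nonce, body = blob[:NONCE_LEN], blob[NONCE_LEN:]
    try:
        return AESGCM(key).decrypt(nonce, body, aad)
    except InvalidTag:
        raise ValueError("DEM authentication failed") from None


# Sketch of the flow with the Sage functions above, assuming they are ported to bytes:
#   sender:   c, k = encapsulate(pk)
#             blob = dem_encrypt(k[:16], message, aad=c)   # send (c, blob)
#   receiver: k = decapsulate(c, sk)                       # False means reject
#             message = dem_decrypt(k[:16], blob, aad=c)
# Taking k[:16] is for brevity; deriving the 16 bytes with a KDF is the more careful choice.
```

Passing the KEM ciphertext `c` as associated data ties the two halves together, so swapping in a different `c` makes decryption fail.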

I am a bit lost here and would like to know whether my assumptions are correct and to be pointed in the right direction. Thanks in advance for any answers.
