qwr
qwr

Reputation: 10955

Finding null space of binary matrix in python

In factoring methods based on the quadratic sieve, finding the left null space of a binary matrix (values computed mod 2) is a crucial step. (This is also the null space of the transpose.) Does numpy or scipy have tools to do this quickly?

For reference, here is my current code:

# Row-reduce binary matrix
def binary_rr(m):
    """Row-reduce a binary matrix to row-echelon form over GF(2), in place.

    Args:
        m: 2-D NumPy array of 0/1 values whose dtype supports ``^=``
            (integer or boolean).

    Returns:
        The same array object ``m`` after reduction.
    """
    rows, cols = m.shape
    col = 0
    for k in range(min(rows, cols)):
        # Locate the next pivot: the first column at or right of `col` that
        # has a 1 in row k *or below*.  Row k itself must be a candidate,
        # otherwise a 1 already sitting in row k is skipped and the result
        # is no longer in echelon form.
        pivot = None
        while col < cols:
            candidates = m[k:, col].nonzero()[0]
            if candidates.size:
                pivot = k + int(candidates[0])
                break
            col += 1

        if pivot is None:
            break  # Every remaining row is zero in every remaining column

        if pivot != k:
            m[[pivot, k]] = m[[k, pivot]]  # Swap rows

        # Clear the pivot column in all rows below the pivot (XOR == add mod 2).
        below = m[k + 1:, col].nonzero()[0] + (k + 1)
        m[below] ^= m[k]

        col += 1

    return m

It is pretty much a straightforward implementation of Gaussian elimination, but since it is written in pure Python, it is very slow.

Upvotes: 4

Views: 1346

Answers (1)

oppressionslayer
oppressionslayer

Reputation: 7224

qwr, I found a very fast Gaussian elimination routine that finishes so quickly that the slow point becomes the Quadratic Sieve or SIQS sieving step. The Gaussian elimination functions were taken from skollmann's factorise.py at https://raw.githubusercontent.com/skollmann/PyFactorise/master/factorise.py

I'll soon be working on a SIQS/GNFS implementation from scratch, and hope to write something super quick for Python with multithreading and possibly Cython. In the meantime, if you want something that compiles C (Alperton's ECM engine) but is used from Python, you can use https://github.com/oppressionslayer/primalitytest/, which requires you to cd into the calculators directory and run make before importing p2ecm with `from sfactorint import p2ecm`. With that you can factorise 60-digit numbers in a few seconds.


# Requires sympy and numpy to be installed.
# Adjust B and I to suit the size of N; the values below are tuned for a 32-digit number.
# Usage:
# N=1009732533765251*1896182711927299
# factorise(N, 5000, 25000000) # Takes about 45-60 seconds on a newer computer
# N=1009732533765251*581120948477
# The linear algebra step finishes in about 1 second, if that
# N=factorise(N, 5000, 2500000) # Takes about 5 seconds on a newer computer
# #Out[1]: 581120948477



import math
import numpy as np
from sympy import isprime  
    
#
# The siqs_ functions are the Gaussian elimination routines taken directly from
# skollmann's factorise.py. It is the fastest Gaussian elimination that I have
# found in Python.
#
    
def siqs_factor_from_square(n, square_indices, smooth_relations):
    """Turn one linear-algebra solution into a candidate factor of n.

    siqs_calc_sqrts supplies a pair (a, b) for the given relation indices;
    the pair is checked to satisfy a*a == b*b (mod n) and the candidate
    factor gcd(|a - b|, n) is returned.  The result may be trivial
    (1 or n).
    """
    a, b = siqs_calc_sqrts(square_indices, smooth_relations)
    # Sanity check: the two squares must agree modulo n.
    assert (a * a) % n == (b * b) % n
    difference = abs(a - b)
    return math.gcd(difference, n)
    
def _siqs_find_more_factors_gcd(numbers):
    """Split a collection of integers further using pairwise GCDs.

    Returns a set containing every input number plus, for each pair that
    shares a non-trivial common divisor, that divisor and both cofactors.
    """
    found = set()
    for a in numbers:
        found.add(a)
        for b in numbers:
            if a == b:
                continue
            common = math.gcd(a, b)
            if common != 1 and common != a and common != b:
                if common not in found:
                    print ("SIQS: GCD found non-trivial factor: %d" % common)
                    found.add(common)
                found.add(a // common)
                found.add(b // common)
    return found


def siqs_find_factors(n, perfect_squares, smooth_relations):
    """Perform the last step of the Self-Initialising Quadratic Sieve.

    Given the solutions returned by siqs_solve_matrix_opt, attempt to
    identify a number of (not necessarily prime) factors of n.

    Args:
        n: the number being factored.
        perfect_squares: iterable of index lists, one per solution.
        smooth_relations: the relations the indices refer to.

    Returns:
        A list of factors of n.  Whatever part of n could not be split
        is appended as the last element, so the product of the list is n.
    """
    factors = []
    rem = n
    non_prime_factors = set()
    prime_factors = set()
    for square_indices in perfect_squares:
        fact = siqs_factor_from_square(n, square_indices, smooth_relations)
        if fact == 1 or fact == rem:
            continue  # Trivial factor, nothing learned from this solution
        if isprime(fact):
            if fact not in prime_factors:
                print ("SIQS: Prime factor found: %d" % fact)
                prime_factors.add(fact)

            while rem % fact == 0:
                factors.append(fact)
                rem //= fact

            if rem == 1:
                break
            if isprime(rem):
                factors.append(rem)
                rem = 1
                break
        else:
            if fact not in non_prime_factors:
                print ("SIQS: Non-prime factor found: %d" % fact)
                non_prime_factors.add(fact)

    if rem != 1 and non_prime_factors:
        # Try to split what is left using GCDs among the composite factors.
        non_prime_factors.add(rem)
        for fact in sorted(_siqs_find_more_factors_gcd(non_prime_factors)):
            while fact != 1 and rem % fact == 0:
                print ("SIQS: Prime factor found: %d" % fact)
                factors.append(fact)
                rem //= fact
            # Use the same primality test as the rest of this function; the
            # previously referenced sfactorint_isprime is not defined here.
            if rem == 1 or isprime(rem):
                break

    if rem != 1:
        factors.append(rem)
    return factors

def add_column_opt(M_opt, tgt, src):
    """Add column ``src`` to column ``tgt`` modulo 2, in place.

    Columns are stored as integers in ``M_opt`` (one bit per row), so the
    mod-2 addition is a bitwise XOR of the two entries.
    """
    M_opt[tgt] = M_opt[tgt] ^ M_opt[src]


def find_pivot_column_opt(M_opt, j):
    """Return the row index of the first 1 in column j, or None.

    Column j is the integer ``M_opt[j]`` whose bit i holds row i, so the
    first non-zero row is the position of the lowest set bit.  None is
    returned when the column is entirely zero.
    """
    column = M_opt[j]
    if column == 0:
        return None
    # `column & -column` isolates the lowest set bit; its bit length minus
    # one is that bit's index.
    lowest_bit = column & -column
    return lowest_bit.bit_length() - 1

def siqs_build_matrix_opt(M):
    """Pack the columns of a 0/1 matrix into integers.

    The j-th returned number encodes the j-th column of M in binary:
    bit i of ``cols[j]`` is set exactly when ``M[i][j]`` is non-zero.

    Args:
        M: non-empty sequence of equally long rows of 0/1 values.

    Returns:
        A tuple ``(cols, n, m)`` where ``cols`` is the list of packed
        columns, ``n = len(M)`` is the number of rows and
        ``m = len(M[0])`` is the number of columns.
    """
    m = len(M[0])
    columns = [0] * m
    for i, row in enumerate(M):
        bit = 1 << i  # Row i lives in bit i of every column integer
        for j, entry in enumerate(row):
            if entry:
                columns[j] |= bit
    return columns, len(M), m


def siqs_solve_matrix_opt(M_opt, n, m):
    """
    Linear algebra step of the SIQS: fast Gaussian elimination over GF(2)
    on a column-packed matrix, following the optimisations in [1].

    ``M_opt`` holds ``m`` columns as integers (bit i = row i) and is
    modified in place; ``n`` is the number of rows.  The result is a list
    of index lists, one per row that never served as a pivot: each starts
    with that row's index followed by the pivot rows of the columns in
    which the row still has a 1.

    [1] Koç, Çetin K., and Sarath N. Arachchige. 'A Fast Algorithm for
        Gaussian Elimination over GF (2) and its Implementation on the
        GAPP.' Journal of Parallel and Distributed Computing 13.1
        (1991): 118-122.
    """
    marked_rows = [False] * n
    pivot_row_of = [-1] * m

    # Elimination: give each non-zero column a pivot row and clear that
    # row's bit from every other column.
    for col in range(m):
        pivot = find_pivot_column_opt(M_opt, col)
        if pivot is None:
            continue
        pivot_row_of[col] = pivot
        marked_rows[pivot] = True
        for other in range(m):
            if other == col:
                continue
            if (M_opt[other] >> pivot) & 1:  # M[pivot][other] == 1
                add_column_opt(M_opt, other, col)

    # Every unmarked row yields one dependency.
    solutions = []
    for row, marked in enumerate(marked_rows):
        if marked:
            continue
        partners = [
            pivot_row_of[col]
            for col in range(m)
            if (M_opt[col] >> row) & 1  # M[row][col] == 1
        ]
        solutions.append([row] + partners)
    return solutions

def sqrt_int(N):
    """Return the exact integer square root of N.

    An assertion fails when N is not a perfect square.
    """
    root = math.isqrt(N)
    assert N == root * root
    return root

def siqs_calc_sqrts(square_indices, smooth_relations):
    """Combine the selected smooth relations into a pair [a, b].

    ``a`` is the product of element 0 of every selected relation; ``b`` is
    the exact integer square root (via sqrt_int) of the product of
    element 1 of every selected relation.  The pair is returned as a
    two-element list.
    """
    lhs = 1
    rhs_squared = 1
    for idx in square_indices:
        relation = smooth_relations[idx]
        lhs *= relation[0]
        rhs_squared *= relation[1]
    return [lhs, sqrt_int(rhs_squared)]
    

def quad_residue(a, n):
    """Return a ** ((n - 1) // 2) modulo n.

    For an odd prime n this is Euler's criterion: the result is 1 when a
    is a non-zero quadratic residue modulo n, n - 1 when it is a
    non-residue, and 0 when n divides a.

    When the exponent (n - 1) // 2 is zero (n is 1 or 2) the function
    returns 1 without reducing anything.
    """
    exponent = (n - 1) // 2
    if exponent == 0:
        return 1
    # Built-in three-argument pow performs the modular exponentiation.
    return pow(a, exponent, n)
    
def STonelli(n, p):
    """Return both square roots of n modulo p as a tuple (r, p - r).

    Follows the Tonelli-Shanks scheme.  An assertion fails when
    quad_residue(n, p) is not 1.
    NOTE(review): p is presumably an odd prime - confirm against callers.
    """
    assert quad_residue(n, p) == 1, "not a square (mod p)"
    # Write p - 1 as q * 2**s with q odd.
    q = p - 1
    s = 0
    
    while q % 2 == 0:
        q //= 2
        s += 1
    if s == 1:
        # Shortcut when p - 1 has exactly one factor of two: a root is
        # n ** ((p + 1) // 4) mod p.
        r = pow(n, (p + 1) // 4, p)
        return r,p-r
    # Find the smallest z >= 2 for which quad_residue(z, p) == p - 1;
    # z is deliberately read after the loop.
    for z in range(2, p):
        if p - 1 == quad_residue(z, p):
            break
    c = pow(z, q, p)
    r = pow(n, (q + 1) // 2, p)
    t = pow(n, q, p)
    m = s
    t2 = 0
    # Iterate until t == 1 (mod p).
    while (t - 1) % p != 0:
        # Repeated squaring: on iteration i, t2 holds t ** (2 ** i) mod p.
        # The loop stops at the first i where that value is 1; i is read
        # after the loop.
        t2 = (t * t) % p
        for i in range(1, m):
            if (t2 - 1) % p == 0:
                break
            t2 = (t2 * t2) % p
        b = pow(c, 1 << (m - i - 1), p)
        r = (r * b) % p
        c = (b * b) % p
        t = (t * c) % p
        m = i

    return (r,p-r)
    
def build_smooth_relations(smooth_base, root_base):
    """Pair every smooth value with its root and its position.

    Returns a list of tuples ``(root_base[i], smooth_base[i], i)``, one
    for each index i of ``smooth_base``.
    """
    return [
        (root_base[position], smooth_value, position)
        for position, smooth_value in enumerate(smooth_base)
    ]
   

def strailing(N):
    """Right-shift N by lars_last_powers_of_two_trailing(N) bits.

    The shift amount is computed by the helper from N itself; the shifted
    value is returned.
    """
    shift = lars_last_powers_of_two_trailing(N)
    return N >> shift


def lars_last_powers_of_two_trailing(N):
    """Return the number of trailing zero bits of N - 1.

    Equivalently, the index of the lowest set bit of N - 1.  Callers that
    want the trailing-zero count of some value x therefore pass x + 1.
    When N - 1 is zero (N == 1) the result is 0.
    """
    below = N - 1
    if below == 0:
        return 0
    # `below & -below` isolates the lowest set bit (also for negative
    # values); its bit length minus one is that bit's index.
    return (below & -below).bit_length() - 1
   
def build_matrix(factor_base, smooth_base):
    """Build the exponent-parity matrix for the smooth values.

    A copy of ``factor_base`` with 2 inserted at the front is used as the
    list of primes; ``factor_base`` itself is left untouched.  For each
    value in ``smooth_base`` and each prime, the entry is 1 when the prime
    divides the value an odd number of times and 0 otherwise.

    Returns the NumPy transpose of the per-value parity rows, i.e. one
    row per prime and one column per smooth value.
    """
    primes_with_two = factor_base.copy()
    primes_with_two.insert(0, 2)

    parity_rows = []
    for value in smooth_base:
        parities = []
        for prime in primes_with_two:
            exponent = 0
            # Divide out the prime completely, counting its multiplicity.
            while value % prime == 0:
                value //= prime
                exponent += 1
            parities.append(exponent % 2)
        parity_rows.append(parities)

    return np.transpose(parity_rows)

def get_mod_congruence(root, N, withstats=False):
    """Return ``root - N``.

    When ``withstats`` equals True, the relation is also printed in the
    form "root ≡ difference mod N".
    """
    difference = root - N
    if withstats == True:
        print(f"{root} ≡ {difference} mod {N}")
    return difference
  
def primes_sieve2(limit):
    """Yield every prime strictly below ``limit`` in increasing order.

    Sieve of Eratosthenes over a NumPy boolean array.  Nothing is yielded
    when ``limit`` is below 2.
    """
    if limit < 2:
        return  # No primes below 2; also avoids indexing a too-short array

    sieve = np.ones(limit, dtype=bool)
    sieve[0] = sieve[1] = False

    for i in range(2, limit):
        if sieve[i]:
            yield i
            # Strike out the multiples of i, starting at i*i, in one
            # slice assignment.
            sieve[i * i::i] = False
                
def remove_singletons(XX):
    """Return a new list with every one-element entry of XX dropped.

    Entries of any other length (including empty ones) are kept in their
    original order.
    """
    return [entry for entry in XX if len(entry) != 1]
                
def fb_sm(N, B, I, max_relations=350):
    """Build the factor base and collect smooth relations for N.

    Args:
        N: the number to factor.
        B: bound for the factor base; only primes below B are considered.
        I: number of sieve positions x = 0 .. I-1, each giving the value
            (isqrt(N) + x)**2 - N.
        max_relations: stop collecting once this many smooth values have
            been found.  The default of 350 suits numbers of up to about
            32 digits; raise it for larger N.

    Returns:
        A tuple ``(factor_base, smooth_base, root_base)``:
        ``factor_base`` lists the odd primes p below B with
        quad_residue(N, p) == 1, ``smooth_base`` lists the sieve values
        that reduced to 1, and ``root_base`` lists the matching
        isqrt(N) + x values.
    """
    root = math.isqrt(N)

    # Skip the first prime (2); it is stripped separately below.
    factor_base = [
        p for p in list(primes_sieve2(B))[1:] if quad_residue(N, p) == 1
    ]

    sieve_list = []  # Original values (root + x)**2 - N
    sieve_base = []  # Working copies that get divided down
    for x in range(I):
        value = get_mod_congruence((root + x) ** 2, N)
        sieve_list.append(value)
        if value % 2 == 0:
            # Bit trick that strips the factors of two from an even value.
            value = strailing(value + 1)
        sieve_base.append(value)

    # For each prime, visit only the positions it divides (given by the
    # two square roots of N modulo p) and divide it out completely.
    for p in factor_base:
        for r in STonelli(N, p):
            for i in range((r - root) % p, len(sieve_list), p):
                while sieve_base[i] % p == 0:
                    sieve_base[i] //= p

    smooth_base, root_base = [], []
    for offset, cofactor in enumerate(sieve_base):
        if len(smooth_base) >= max_relations:
            break
        if cofactor == 1:  # Fully factored over the factor base
            smooth_base.append(sieve_list[offset])
            root_base.append(root + offset)

    return factor_base, smooth_base, root_base
   
def isSquare(hm):
    """Return True when hm is a perfect square, otherwise False."""
    root = math.isqrt(hm)
    return root * root == hm
  
def find_square(smooth_base):
    """Look for a perfect square among the smooth values.

    Returns:
        ``(True, index)`` for the first element of ``smooth_base`` that
        isSquare accepts, or ``(False, -1)`` when there is none
        (including when ``smooth_base`` is empty).
    """
    for index, value in enumerate(smooth_base):
        if isSquare(value):
            return (True, index)
    # Only give up after every element has been examined.
    return (False, -1)

# Module-level scratch variable; factorise() declares it global and rebinds it.
t_matrix=[]
# Primes below 1,000,000, computed once at import time; factorise() uses
# them for trial division.
primes=list(primes_sieve2(1000000))


def factorise(N, B=10000, I=10000000):
    """Find factors of N using trial division and a quadratic sieve.

    Args:
        N: the integer to factor.
        B: factor-base bound passed to fb_sm.
        I: sieve interval length passed to fb_sm.

    Returns:
        N itself when isprime(N) holds; a single prime divisor when one
        of the precomputed module-level ``primes`` divides N; a single
        gcd-derived integer when one of the smooth values is a perfect
        square; otherwise the list produced by siqs_find_factors.

    The module-level ``t_matrix`` is rebound as a side effect.
    """
    global primes, t_matrix

    if isprime(N):
        return N

    # Trial division by the precomputed small primes.
    for small_prime in primes:
        if N % small_prime == 0:
            return small_prime

    factor_base, smooth_base, root_base = fb_sm(N, B, I)

    # Shortcut: a smooth value that is itself a perfect square gives a
    # difference of squares directly.
    found_square, t_matrix = find_square(smooth_base)
    if found_square:
        square_root = math.isqrt(smooth_base[t_matrix])
        shifted_root = get_mod_congruence(root_base[t_matrix], N)
        return math.gcd(square_root + shifted_root, N)

    # Linear algebra over GF(2) on the exponent-parity matrix.
    t_matrix = build_matrix(factor_base, smooth_base)
    smooth_relations = build_smooth_relations(smooth_base, root_base)
    M_opt, M_n, M_m = siqs_build_matrix_opt(np.transpose(t_matrix))
    dependencies = siqs_solve_matrix_opt(M_opt, M_n, M_m)
    perfect_squares = remove_singletons(dependencies)

    return siqs_find_factors(N, perfect_squares, smooth_relations)

Upvotes: 1

Related Questions