Dino 1729

Reputation: 37

Algorithm for expressing given number as a sum of two squares

My problem is as follows:

I'm given a natural number n and I want to find all natural numbers x and y such that

n = x² + y²

Since addition is commutative, order does not matter, so I count (x, y) and (y, x) as one solution.

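My initial algorithm assumes x <= y (equivalently 2x² <= n): for each such x, compute y² = n - x² and check whether it is a perfect square using binary search. Using <= rather than < keeps solutions with x = y, such as 50 = 5² + 5².


for (x = 1; 2 * x * x <= n; x++)   // x <= y  <=>  2x² <= n
{
     y_squared = n - x * x;
     if (!isSquare(y_squared))     // binary search for an integer root
           continue;

     // rest of the code
}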
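For reference, a minimal Python sketch of the approach described above, assuming x and y must be at least 1 (as in the loop) and with the binary-search square test written out. The names exact_sqrt and solutions_by_search are hypothetical, not from the question; exact_sqrt returns the root itself so that y is recovered without floating-point rounding.

```python
from typing import List, Optional, Tuple


def exact_sqrt(m: int) -> Optional[int]:
    """Binary search for r with r * r == m (m >= 0); None if m is not a perfect square."""
    lo, hi = 0, m
    while lo <= hi:
        mid = (lo + hi) // 2
        sq = mid * mid
        if sq == m:
            return mid
        if sq < m:
            lo = mid + 1
        else:
            hi = mid - 1
    return None


def solutions_by_search(n: int) -> List[Tuple[int, int]]:
    """All pairs 1 <= x <= y with x*x + y*y == n, testing each candidate x."""
    result = []
    x = 1
    while 2 * x * x <= n:          # x <= y  <=>  2x^2 <= n
        y = exact_sqrt(n - x * x)  # O(log n) per candidate
        if y is not None:
            result.append((x, y))
        x += 1
    return result
```

With about sqrt(n / 2) candidates and one binary search each, this matches the O(sqrt(n) * log(n)) estimate in the question.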

Can my algorithm be improved? I already check whether n has any solutions using the sum of two squares theorem, but I also want to know how many there are.
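A hedged sketch of the sum of two squares theorem check mentioned above: n >= 1 is a sum of two squares exactly when every prime p with p % 4 == 3 divides n an even number of times. The function name is hypothetical, and it factors n by plain trial division, which costs O(sqrt(n)) in the worst case. Note that the theorem allows 0 as one of the squares, so n = 4 = 0² + 2² passes the check even though the loop above (which starts at x = 1) finds no solution for it.

```python
def is_sum_of_two_squares(n: int) -> bool:
    """True iff n == x*x + y*y for some integers x, y >= 0."""
    if n < 0:
        return False
    if n == 0:
        return True  # 0 = 0^2 + 0^2
    m = n
    p = 2
    while p * p <= m:
        if m % p == 0:
            exponent = 0
            while m % p == 0:  # strip every factor p and count them
                m //= p
                exponent += 1
            if p % 4 == 3 and exponent % 2 == 1:
                return False
        p += 1
    # Anything left over (m > 1) is a prime that divides n exactly once.
    return not (m > 1 and m % 4 == 3)
```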

My algorithm runs in O(sqrt(n) * log(n)) time: about sqrt(n / 2) values of x, each with an O(log(n)) binary search.

Thanks in advance.

Upvotes: 1

Views: 1214

Answers (2)

Kelly Bundy

Reputation: 27589

A variation of Paul's answer that keeps track of the sum of squares and updates it using only additions and subtractions:

Pseudo-code (evaluate x++ + x and y-- + y left to right, or do it as in the Python code below):

x = 0
y = floor(sqrt(n))
sum = y * y

while x <= y
    if sum < n
        sum += x++ + x
    else if sum > n
        sum -= y-- + y
    else
        print(x, y)
        sum += 2 * (++x - y--)
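The three updates follow from expanding the squares (standard algebra, written out here for clarity). Each line matches one branch: x++ + x equals 2x + 1, y-- + y equals 2y - 1, and 2 * (++x - y--) equals 2((x + 1) - y).

```latex
% x increases by one: the sum grows by x + (x + 1)
(x+1)^2 + y^2 = (x^2 + y^2) + (2x + 1)
% y decreases by one: the sum shrinks by y + (y - 1)
x^2 + (y-1)^2 = (x^2 + y^2) - (2y - 1)
% match: both move at once
(x+1)^2 + (y-1)^2 = (x^2 + y^2) + 2\bigl((x+1) - y\bigr)
```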

Java:

  static void allSolutions(int n) {
    int x = 0, y = (int) Math.sqrt(n), sum = y * y;
    while (x <= y) {
      if (sum < n) {
        sum += x++ + x;
      } else if (sum > n) {
        sum -= y-- + y;
      } else {
        System.out.println(x + " " + y);
        sum += 2 * (++x - y--);
      }
    }
  }

Python:

from math import isqrt

def all_solutions(n):
    x = 0
    y = isqrt(n)
    total = y ** 2          # invariant: total == x**2 + y**2

    while x <= y:
        if total < n:
            x += 1
            total += 2 * x - 1
        elif total > n:
            total -= 2 * y - 1
            y -= 1
        else:
            # found a match
            print(x, y)
            x += 1
            total += 2 * (x - y)
            y -= 1

Demo:

>>> all_solutions(5525)
7 74
14 73
22 71
25 70
41 62
50 55
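To compare results without reading printed output, a minimal variant of the Python function above that returns the pairs as a list, plus a brute-force reference. Both function names are hypothetical; the loop logic is unchanged apart from collecting matches. Like the original, it starts at x = 0, so a perfect square n also yields the pair (0, isqrt(n)).

```python
from math import isqrt
from typing import List, Tuple


def all_solutions_list(n: int) -> List[Tuple[int, int]]:
    """Same incremental scan as all_solutions, but collects the pairs."""
    x, y = 0, isqrt(n)
    total = y * y                      # invariant: total == x*x + y*y
    found = []
    while x <= y:
        if total < n:
            total += 2 * x + 1         # (x+1)^2 - x^2
            x += 1
        elif total > n:
            total -= 2 * y - 1         # y^2 - (y-1)^2
            y -= 1
        else:
            found.append((x, y))
            total += 2 * (x + 1 - y)   # move both x and y at once
            x += 1
            y -= 1
    return found


def brute_force(n: int) -> List[Tuple[int, int]]:
    """Check every pair 0 <= x <= y <= isqrt(n); slow but obviously correct."""
    r = isqrt(n)
    return [(x, y) for x in range(r + 1) for y in range(x, r + 1) if x * x + y * y == n]


print(all_solutions_list(5525))
# [(7, 74), (14, 73), (22, 71), (25, 70), (41, 62), (50, 55)]
```

The brute-force version is quadratic in sqrt(n), so it is only practical for cross-checking small inputs.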

Upvotes: 1

user4668606


You can reduce this to O(sqrt(n)) this way:

all_solutions(n):
    x = 0
    y = floor(sqrt(n))

    while x <= y
        if x * x + y * y < n
            x++
        else if x * x + y * y > n
            y--
        else
            // found a match
            print(x, y)
            x++
            y--

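This algorithm finds and prints all solutions. It never skips one: x is only incremented when even the largest remaining y is too small, and y is only decremented when even the smallest remaining x is too large. Each iteration increases x, decreases y, or both, and the loop stops as soon as x > y. Because x only increases while x <= y and x² + y² < n, it never grows much beyond sqrt(n / 2); by the same argument y never drops much below sqrt(n / 2). That gives at most about sqrt(n / 2) + (sqrt(n) - sqrt(n / 2)) = sqrt(n) iterations.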
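A runnable Python translation of the pseudocode above, instrumented with a hypothetical iteration counter so the bound can be checked empirically. Because each iteration moves x up or y down (or both) and the loop stops once x > y, the count should never exceed isqrt(n) + 1 in this sketch (the + 1 comes from rounding where x and y cross). Using math.isqrt rather than a floating-point square root avoids rounding errors for large n.

```python
from math import isqrt
from typing import List, Tuple


def two_pointer_with_count(n: int) -> Tuple[List[Tuple[int, int]], int]:
    """Direct translation of the pseudocode above; also counts loop iterations."""
    x, y = 0, isqrt(n)   # floor(sqrt(n)), exact even for very large n
    found: List[Tuple[int, int]] = []
    iterations = 0
    while x <= y:
        iterations += 1
        s = x * x + y * y
        if s < n:
            x += 1
        elif s > n:
            y -= 1
        else:
            found.append((x, y))   # found a match
            x += 1
            y -= 1
    return found, iterations


for n in (5525, 10**6):
    pairs, steps = two_pointer_with_count(n)
    print(n, "solutions:", len(pairs), "iterations:", steps, "isqrt(n):", isqrt(n))
```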

Upvotes: 1
