David Simka

Reputation: 566

Recursion that branches running out of memory

I have a programming assignment that goes like this:

You are given three numbers a, b, and c (1 ≤ a, b, c ≤ 10^18). At each step you have two choices: either add b to a (a += b), or add a to b (b += a). Write a program that prints YES or NO depending on whether you can make a or b equal to c by repeatedly adding them to each other.

I've tried solving this problem with a recursion that splits into two branches at every step: one branch continues with (a+b, b) and the other with (a, b+a). In every recursive call, the function checks the values of a and b, and if either of them equals c, the search stops and the function prints YES. The recursion stops when either a or b becomes greater than c.

Here's how the branching works: [image: diagram of the branching recursion tree]

And here's the code in C:

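#include <stdio.h>
#include <stdlib.h>

// Explore both choices from (a, b); print YES and stop as soon as a or b equals c.
void tree(long long int a, long long int b, long long int c){
    if(a==c || b==c){
        printf("YES\n");
        exit(0);
    }
    else if(a<c && b<c){
        tree(a, b+a, c);   // branch 1: b += a
        tree(a+b, b, c);   // branch 2: a += b
    }
}

int main(){
    long long int a, b, c;
    // %lld is the standard specifier for long long (%I64d is Microsoft-specific).
    if (scanf("%lld %lld %lld", &a, &b, &c) != 3)
        return 1;

    tree(a, b, c);

    printf("NO\n");

    return 0;
}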

Now, this program works for small numbers, but since a, b and c can be as large as 10^18, the recursion can nest billions of calls deep (with a = 1, for example, the first branch only increases b by 1 per call), so the program runs out of stack memory and crashes.

My question is: is there any way I can improve my code, or use some other approach (other than recursion) to solve this?

Upvotes: 4

Views: 353

Answers (2)

JS1

Reputation: 4767

OK, I have to admit that this turned out to be a fascinating question. I really thought there should be a quick way of finding the answer, but the more I looked at the problem, the more complex it became. For example, if you zigzag down the tree, alternating a+=b with b+=a, you are essentially creating the Fibonacci sequence (imagine starting with a=2 and b=3). That means that if you could find the answer quickly, you could somehow use a similar program to answer "is c a Fibonacci number?"

So I never came up with anything better than searching the binary tree. But I did come up with a way to search the binary tree without running out of memory. The key trick in my algorithm is that at every node you need to search two child nodes, but you don't need to recurse down both paths. You only need to recurse down one path, and if that fails, you can iterate to the other child. When recursing, you should always pick the path where the smaller number changes. This guarantees that the minimum element grows at least as fast as the Fibonacci sequence (it at least doubles every two recursion levels), so the recursion can never go more than about 90 levels deep before the minimum element would exceed 2^64.

So I wrote the program and ran it, and it worked just fine. That is, until I entered a very large number for c; at that point, it didn't finish. I found from testing that the algorithm appears to have O(N^2) running time, where N = c. Here are some sample running times (all on a desktop running 64-bit Windows):

Inputs                              Time (min:sec)
------                              --------------
a=2   b=3   c=10000000000  (10^10):  0:20
a=2   b=3   c=100000000000 (10^11): 13:42
a=2   b=3   c=100000000001        :  2:21 (randomly found the answer quickly)
a=2   b=3   c=100000000002        : 16:36
a=150 b=207 c=10000000     (10^7) :  0:08 (no solution)
a=150 b=207 c=20000000            :  0:31 (no solution)
a=150 b=207 c=40000000            :  2:05 (no solution)
a=150 b=207 c=100000000    (10^8) : 12:48 (no solution)

Here is my code:

// Given three numbers: a, b, c.
//
// At each step, either do: a += b, or b += a.
// Can you make either a or b equal to c?
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>

static int solve(uint64_t a, uint64_t b, uint64_t c);

int main(int argc, char *argv[])
{
    uint64_t a = 0, b = 0, c = 0;

    if (argc < 4) {
        printf("Usage: %s a b c\n", argv[0]);
        exit(0);
    }
    a = strtoull(argv[1], NULL, 0);
    b = strtoull(argv[2], NULL, 0);
    c = strtoull(argv[3], NULL, 0);

    // Note, by checking to see if a or b are solutions here, solve() can
    // be made simpler by only checking a + b == c.  That speeds up solve().
    if (a == c || b == c || solve(a, b, c))
        printf("There is a solution\n");
    else
        printf("There is NO solution\n");
    return 0;
}

int solve(uint64_t a, uint64_t b, uint64_t c)
{
    do {
        uint64_t sum = a + b;
        // Overshot c: no solution along this path.
        if (sum > c)
            return 0;
        // Check for solution.
        if (sum == c)
            return 1;
        // The algorithm is to search both branches (a += b and b += a).
        // But first we search the branch where we add the higher number to
        // the lower number: along that branch the lower number grows at
        // least as fast as the Fibonacci sequence (it at least doubles every
        // two levels), so we will recurse fewer than 100 levels deep.  Then,
        // if that doesn't work out, we iterate to the other branch.
        if (a < b) {
            // Try a += b recursively.
            if (solve(sum, b, c))
                return 1;
            // Failing that, try b += a.
            b = sum;
        } else {
            // Try b += a recursively.
            if (solve(a, sum, c))
                return 1;
            // Failing that, try a += b.
            a = sum;
        }
    } while(1);
}

Edit: I optimized the above program by removing the recursion, reordering the arguments so that a is always less than b at every step, and a few other tricks. It runs about 50% faster than before. You can find the optimized program here.

Upvotes: 4

Weather Vane

Reputation: 34585

Based on a comment from @Oliver Charlesworth: this is an iterative, not recursive, solution, so it won't solve the homework. But it's pretty simple. I step through multiples of b because b is larger than a (although that is not entirely clear from the OP), so fewer iterations are required.

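#include <stdio.h>

int main(){
    unsigned long long int a, b, c, bb;
    // %llu is the standard specifier for unsigned long long (%I64u is Microsoft-specific).
    if (scanf("%llu %llu %llu", &a, &b, &c) != 3)
        return 1;

    if (a >= 1 && a < b && b < c) {
        // Step bb through the multiples of b below c; answer YES if the
        // remainder c - bb is a multiple of a.
        for (bb=b; bb<c; bb+=b) {
            if ((c - bb) % a == 0) {
                printf ("YES\n");
                return 0;
            }
        }
    }
    printf("NO\n");
    return 0;
}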

Upvotes: 1
