oblitum

Can this be tail call optimized? If so, what's the specific reason it doesn't happen?

I've checked the assembly output at many optimization levels with both gcc 4.8.1 and clang 3.4.190255, and neither applies tail call optimization to this kind of code.

Is there a specific reason why collatz_aux doesn't get tail call optimized?

#include <vector>
#include <cassert>
#include <utility>  // std::move

using namespace std;

vector<unsigned> concat(vector<unsigned> v, unsigned n) {
    v.push_back(n);
    return v;
}

vector<unsigned> collatz_aux(unsigned n, vector<unsigned> result) {
    return n == 1
        ? result
        : n % 2 == 0
            ? collatz_aux(n / 2, concat(move(result), n))
            : collatz_aux(3 * n + 1, concat(move(result), n));
}

vector<unsigned> collatz_vec(unsigned n) {
    assert(n != 0);
    return collatz_aux(n, {});
}

int main() {
    return collatz_vec(10).size();
}

oblitum

For reference, this is how I tweaked the recursive version to get tail recursion:

#include <vector>
#include <cassert>
#include <utility>  // std::forward

using namespace std;

template<class container>
container &&collatz_aux(unsigned n, container &&result) {
    static auto concat = [](container &&c, unsigned n) -> container &&{
        c.push_back(n);
        return forward<container>(c);
    };

    return n == 1
        ? forward<container>(result)
        : n % 2 == 0
            ? collatz_aux(n / 2, concat(forward<container>(result), n))
            : collatz_aux(3 * n + 1, concat(forward<container>(result), n));
}

vector<unsigned> collatz_vec(unsigned n) {
    assert(n != 0);
    return collatz_aux(n, vector<unsigned>{});
}

int main() {
    return collatz_vec(10).size();
}
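The same idea can be sketched more simply if the vector is owned by the caller and filled through a non-const reference instead of being passed and returned by value. No vector temporary is created for the recursive call, so nothing is left to destroy after it returns and the call is in tail position. Whether a given compiler actually emits a jump still depends on its optimization settings. The names collatz_into and collatz_ref are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: appends the Collatz sequence starting at n (excluding
// the final 1) to `out`. `out` is passed by reference, so no vector has to be
// destroyed after the recursive call returns.
void collatz_into(unsigned n, std::vector<unsigned>& out) {
    if (n == 1) return;
    out.push_back(n);
    // The recursive call is the last thing this function does (tail position).
    collatz_into(n % 2 == 0 ? n / 2 : 3 * n + 1, out);
}

// Hypothetical wrapper mirroring collatz_vec from the question.
std::vector<unsigned> collatz_ref(unsigned n) {
    assert(n != 0);  // 0 would recurse forever, as in the original
    std::vector<unsigned> result;
    collatz_into(n, result);
    return result;
}
```

For n = 10, collatz_ref returns 10, 5, 16, 8, 4, 2, the same values as the question's collatz_vec.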

Ben

You shouldn't rely on tail call optimization here; the C++ standard doesn't guarantee it. I think it unlikely that the optimizer will spot that both recursive calls can be tail-call optimized.

Here's a non-recursive version (it reuses concat from the question):

vector<unsigned> collatz_aux(unsigned n, vector<unsigned> result) {
    while (true) {
        if (n == 1) return result;
        result = concat(move(result), n);
        if (n % 2 == 0) {
            n = n / 2;
        } else {
            n = 3 * n + 1;
        }
    }
}
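A self-contained version of the loop, for checking it against the recursive original. It assumes calling push_back directly is an acceptable stand-in for the question's concat, which appends the same value via extra moves. The name collatz_iter is hypothetical.

```cpp
#include <cassert>
#include <vector>

// Iterative equivalent of collatz_vec/collatz_aux from the question:
// records every value before reaching 1, without any recursion.
std::vector<unsigned> collatz_iter(unsigned n) {
    assert(n != 0);  // 0 would loop forever, as in the original
    std::vector<unsigned> result;
    while (n != 1) {
        result.push_back(n);
        n = (n % 2 == 0) ? n / 2 : 3 * n + 1;
    }
    return result;
}
```

For n = 10 this yields 10, 5, 16, 8, 4, 2 (the final 1 is not stored), which is why the question's main returns 6.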

Simple

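The destructor for the vector<unsigned> passed by value still has to run after the recursive call returns, so the call is not the last thing collatz_aux does and cannot simply be turned into a jump.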
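A small sketch of this effect, using a hypothetical Tracked type whose destructor logs its id in place of vector<unsigned>. The log shows that the deepest call finishes (marker -1) before any by-value argument is destroyed, so every frame still has destructor work pending after its recursive call returns. Whether a parameter is destroyed when the callee returns or at the end of the caller's full-expression is implementation-defined, but either way it happens after the callee's body has run. Assuming C++17 guaranteed copy elision, destruction_order() returns -1, 1, 2, 3, 0.

```cpp
#include <cassert>
#include <vector>

// Records the order of events; purely for illustration.
std::vector<int> event_log;

// Hypothetical type with a non-trivial destructor, standing in for vector<unsigned>.
struct Tracked {
    int id;
    explicit Tracked(int i) : id(i) {}
    ~Tracked() { event_log.push_back(id); }
};

// Same shape as collatz_aux: the recursive call receives an object by value.
int recurse(int n, Tracked t) {
    (void)t;
    if (n == 0) {
        event_log.push_back(-1);  // marker: deepest call reached
        return 0;
    }
    // Tracked{n} must be destroyed after this call returns, so the call
    // is not in tail position.
    return recurse(n - 1, Tracked{n});
}

std::vector<int> destruction_order() {
    event_log.clear();
    recurse(3, Tracked{0});
    return event_log;
}
```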
