Robin Yu
Robin Yu

Reputation: 339

Computing Nth triangular number that is also a square number

This problem came out in a practice contest:

Compute the Nth triangular number that is also a square number, modulo 10006699. (1 ≤ N ≤ 10^18) There are up to 10^5 test cases.

I found out that I can easily compute it with the recurrence relation Ti = 6Ti-1 - Ti-2 + 2, with T0 = 0 and T1 = 1.

I'm using matrix exponentiation for approximately O(log N) performance per test case, but it's apparently too slow, since there are 10^5 test cases. In fact, this code is too slow even when the constraints are only (1 ≤ N ≤ 10^6), where I could just do an O(N) preprocessing and O(1) query.

Should I change my approach to the problem, or should I just optimize some parts of the code?

#include <ios>
#include <iostream>
#include <vector>
#define MOD 10006699

/*
Transformation Matrix:

 0 1 0   t[i]     t[i+1]
-1 6 1 * t[i+1] = t[i+2]
 0 0 1     2        2
*/

std::vector<std::vector<long long int> > multi(std::vector<std::vector<long long int> > a, std::vector<std::vector<long long int> > b)
{
    std::vector<std::vector<long long int> > c(3, std::vector<long long int>(3));
    for (int i = 0; i < 3; i++)
    {
        for (int j = 0; j < 3; j++)
        {
            for (int k = 0; k < 3; k++)
            {
                c[i][j] += (a[i][k] * b[k][j]) % MOD;
                c[i][j] %= MOD;
            }
        }
    }
    return c;
}

std::vector<std::vector<long long int> > power(std::vector<std::vector<long long int> > vec, long long int p)
{
    if (p == 1) return vec;
    else if (p % 2 == 1) return multi(vec, power(vec, p-1));
    else
    {
        std::vector<std::vector<long long int> > x = power(vec, p/2);
        return multi(x, x);
    }
}

int main()
{
    std::ios_base::sync_with_stdio(false);
    long long int n;
    while (std::cin >> n)
    {
        if (n == 0) break;
        else
        {
            std::vector<std::vector<long long int> > trans;
            long long int ans;
            trans.resize(3);

            trans[0].push_back(0);  
            trans[0].push_back(1);
            trans[0].push_back(0);
            trans[1].push_back(-1);
            trans[1].push_back(6);
            trans[1].push_back(1);
            trans[2].push_back(0);
            trans[2].push_back(0);
            trans[2].push_back(1);

            trans = power(trans, n);

            ans = (trans[0][1]%MOD + (2*trans[0][2])%MOD)%MOD;

            if (ans < 0) ans += MOD;

            std::cout << ans << std::endl;
        }
    }
}

Upvotes: 3

Views: 431

Answers (1)

asimes
asimes

Reputation: 5894

Note: I removed my old answer, this is more useful

It seems unlikely that you will create a better asymptotic algorithm than O(log N) for the problem. However, there are modifications to that can be performed to your current code which will not improve the asymptotic time but will improve the performance

Below is a modification of your code that produces the same answer:

#include <ctime>
#include <ios>
#include <iostream>
#include <vector>
#define MOD 10006699

void power(std::vector<std::vector<long long int> >& vec, long long int p)
{
    if (p == 1)
        return;

    else if (p & 1)
    {
        std::vector<std::vector<long long int> > copy1 = vec;
        power(copy1, p-1);

        std::vector<std::vector<long long int> > copy2(3, std::vector<long long int>(3));
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 3; j++)
            {
                for (int k = 0; k < 3; k++)
                    copy2[i][j] += (vec[i][k] * copy1[k][j]) % MOD;
                copy2[i][j] %= MOD;
            }
        vec = copy2;

        return;
    }

    else
    {
        power(vec, p/2);

        std::vector<std::vector<long long int> > copy(3, std::vector<long long int>(3));
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 3; j++)
            {
                for (int k = 0; k < 3; k++)
                    copy[i][j] += (vec[i][k] * vec[k][j]) % MOD;
                copy[i][j] %= MOD;
            }
        vec = copy;

        return;
    }
}

int main()
{
    std::ios_base::sync_with_stdio(false);
    long long int n;
    while (std::cin >> n)
    {
        std::clock_t start = std::clock();
        if (n == 0) break;

        std::vector<std::vector<long long int> > trans;
        long long int ans;
        trans.resize(3);

        trans[0].push_back(0);  
        trans[0].push_back(1);
        trans[0].push_back(0);
        trans[1].push_back(-1);
        trans[1].push_back(6);
        trans[1].push_back(1);
        trans[2].push_back(0);
        trans[2].push_back(0);
        trans[2].push_back(1);

        power(trans, n);

        ans = (trans[0][1]%MOD + (2*trans[0][2])%MOD)%MOD;
        if (ans < 0) ans += MOD;
        std::cout << "Answer: " << ans << std::endl;

        std::cout << "Time: " << (std::clock() - start) / (double)(CLOCKS_PER_SEC / 1000) << " ms" << std::endl;
    }
}

The differences mainly are:

  • A code motion of c[i][j] %= MOD; to be outside the k loop
  • Passing vectors by reference
  • Less function calls

If I place the same timing code in the while loop of your main as I have in my code, name your file "before.cpp", name my file "after.cpp", and run each 10 times in a row with full optimizations then these are my results:

Alexanders-MBP:Desktop alexandersimes$ g++ before.cpp -O3 -o before
Alexanders-MBP:Desktop alexandersimes$ ./before 
1000000000000000000
Answer: 6635296
Time: 0.708 ms
1000000000000000000
Answer: 6635296
Time: 0.542 ms
1000000000000000000
Answer: 6635296
Time: 0.688 ms
1000000000000000000
Answer: 6635296
Time: 0.634 ms
1000000000000000000
Answer: 6635296
Time: 0.626 ms
1000000000000000000
Answer: 6635296
Time: 0.629 ms
1000000000000000000
Answer: 6635296
Time: 0.629 ms
1000000000000000000
Answer: 6635296
Time: 0.629 ms
1000000000000000000
Answer: 6635296
Time: 0.632 ms
1000000000000000000
Answer: 6635296
Time: 0.695 ms

Alexanders-MBP:Desktop alexandersimes$ g++ after.cpp -O3 -o after
Alexanders-MBP:Desktop alexandersimes$ ./after 
1000000000000000000
Answer: 6635296
Time: 0.283 ms
1000000000000000000
Answer: 6635296
Time: 0.287 ms
1000000000000000000
Answer: 6635296
Time: 0.27 ms
1000000000000000000
Answer: 6635296
Time: 0.27 ms
1000000000000000000
Answer: 6635296
Time: 0.266 ms
1000000000000000000
Answer: 6635296
Time: 0.265 ms
1000000000000000000
Answer: 6635296
Time: 0.266 ms
1000000000000000000
Answer: 6635296
Time: 0.267 ms
1000000000000000000
Answer: 6635296
Time: 0.21 ms
1000000000000000000
Answer: 6635296
Time: 0.208 ms

Upvotes: 1

Related Questions