Ξένη Γήινος
Ξένη Γήινος

Reputation: 3064

Efficient ways to compute nth root in C++ without using cmath

How can I efficiently compute the nth root of a number to at least 12 correct decimal places without using cmath or similar library functions?

I have tried to solve it myself. My idea is to find an approximation and use Newton's method to make the approximation more accurate.

I have implemented two methods: one uses binary search, and the other is based on the fast inverse square root algorithm.

#include <array>
#include <chrono>
#include <cmath>
#include <iostream>
#include <vector>

using std::chrono::steady_clock;
using std::chrono::duration;
using std::cout;
using std::vector;
float r = 0.0;

// Raises `base` to the integer power `exp` using square-and-multiply.
// A zero exponent yields 1; a negative exponent is handled by working with
// the reciprocal of `base` and the magnitude of `exp`.
inline float power(float base, int exp) {
    if (exp == 0) {
        return 1.0f;
    }
    // `square` holds base^(2^k) after k rounds; `remaining` is the exponent
    // still to be consumed.
    float square = exp < 0 ? 1 / base : base;
    int remaining = exp < 0 ? -exp : exp;
    // Product of the squares that correspond to set low bits of the exponent.
    float odd_product = 1.0f;
    for (; remaining > 1; remaining >>= 1) {
        if (remaining & 1) {
            odd_product *= square;
        }
        square *= square;
    }
    // The top bit of the exponent is always set, so fold in the last square.
    return square * odd_product;
}

// Computes the n-th root of `base` in two phases:
//   1. 12 bisection steps to obtain a rough starting point,
//   2. 12 Newton iterations on f(x) = x^n - base to refine it.
// Intended for base >= 0 and n >= 1; other inputs are not validated.
// Results are limited to single (float) precision.
inline float nth_root(float base, int n) {
    if (base == 0.0f) {
        // The Newton step below would evaluate 0 / 0 for a zero input.
        return 0.0f;
    }
    const int n1 = n - 1;

    // For base < 1 the root is larger than base itself, so the bracket must
    // extend to at least 1 to be guaranteed to contain the root.
    float lo = 0.0f;
    float hi = base > 1.0f ? base : 1.0f;
    for (int i = 0; i < 12; i++) {
        const float mid = (lo + hi) / 2;
        // mid^n <= base  =>  the root lies in the upper half.
        if (mid * power(mid, n1) - base <= 0) {
            lo = mid;
        }
        else {
            hi = mid;
        }
    }

    float x = (lo + hi) / 2;
    const float inv_n = 1.0f / n;
    for (int i = 0; i < 12; i++) {
        // Newton step: x <- ((n - 1) * x + base / x^(n - 1)) / n.
        x = inv_n * (n1 * x + base / power(x, n1));
    }
    return x;
}

// Approximates the n-th root of `base` by iterating on its reciprocal.
// An initial guess for base^(-1/n) is produced by integer arithmetic on the
// IEEE-754 bit pattern of `base` (in the spirit of the fast inverse square
// root trick), then refined by six division-free Newton steps, and finally
// inverted.
inline float fast_nth_root(float base, int n)
{
   const float inv_n = 1.0 / n;
   const uint32_t base_bits = std::bit_cast<uint32_t>(base);
   // 0x3F7A3BEA sits just below the bit pattern of 1.0f (0x3F800000);
   // presumably a tuned bias for the initial guess.
   const uint32_t guess_bits = 0x3F7A3BEA * inv_n * (n + 1) - base_bits * inv_n;
   float y = std::bit_cast<float>(guess_bits);
   int remaining = 6;
   while (remaining-- > 0) {
       // Newton step for f(y) = y^(-n) - base: y <- y * (n + 1 - base * y^n) / n.
       y = y * (n + 1 - base * power(y, n)) * inv_n;
   }
   return 1.0 / y;
}

int main()
{
    vector<float> bases(256);
    vector<int> ns(256);
    float r256 = 1.0 / 256;
    for (int i = 0; i < 256; i++) {
        bases[i] = 1.0 + rand() % 16384 + (rand() % 256) * r256;
        ns[i] = 2 + rand() % 30;
    }
    auto start = steady_clock::now();
    for (int64_t i = 0; i < 1048576; i++) {
        r += nth_root(bases[i % 256], ns[i % 256]);
    }
    auto end = steady_clock::now();
    duration<double, std::nano> time = end - start;
    cout << "nth_root: " << time.count() / 1048576 << " nanoseconds\n";
    start = steady_clock::now();
    for (int64_t i = 0; i < 1048576; i++) {
        r += pow(bases[i % 256], 1.0 / ns[i % 256]);
    }
    end = steady_clock::now();
    time = end - start;
    cout << "pow: " << time.count() / 1048576 << " nanoseconds\n";
    start = steady_clock::now();
    for (int64_t i = 0; i < 1048576; i++) {
        r += fast_nth_root(bases[i % 256], ns[i % 256]);
    }
    end = steady_clock::now();
    time = end - start;
    cout << "fast_nth_root: " << time.count() / 1048576 << " nanoseconds\n";
}

Compiled with:

g++.exe -Wall -fexceptions -fomit-frame-pointer -fexpensive-optimizations -flto -O3 -m64 --std=c++20 -march=native -ffast-math  -c D:\MyScript\CodeBlocks\testapp\main.cpp -o obj\Release\main.o
g++.exe  -o bin\Release\testapp.exe obj\Release\main.o  -O3 -flto -s -static-libstdc++ -static-libgcc -static -m64  
PS C:\Users\Xeni> D:\MyScript\CodeBlocks\testapp\bin\Release\testapp.exe
nth_root: 318.12 nanoseconds
pow: 104.77 nanoseconds
fast_nth_root: 53.4222 nanoseconds

The first method is very slow, as expected. The second method is faster than the library code, but it may not be as accurate.

According to my tests in Python:

import random, struct

def root(num, p, lim, lin):
    """Approximate the p-th root of ``num``.

    Runs ``lim`` bisection steps to get a starting point, then ``lin``
    Newton iterations on ``f(x) = x**p - num`` to refine it.

    Args:
        num: Non-negative number whose root is wanted.
        p: Integer order of the root (p >= 1).
        lim: Number of bisection steps.
        lin: Number of Newton iterations.

    Returns:
        The approximation of ``num ** (1 / p)`` as a float.
    """
    if num == 0:
        # The Newton step would divide by 0 ** (p - 1).
        return 0.0

    lo = 0
    # For num < 1 the root is larger than num, so the bracket has to reach 1
    # to be guaranteed to contain it.
    hi = max(num, 1)
    for _ in range(lim):
        x = (lo + hi) / 2
        if x * x ** (p - 1) - num <= 0:
            lo = x
        else:
            hi = x

    x = (lo + hi) / 2
    inv_p = 1 / p
    q = p - 1
    for _ in range(lin):
        # Newton step: x <- ((p - 1) * x + num / x**(p - 1)) / p
        x = inv_p * (q * x + num / x ** q)

    return x

# For 64 random (n, p) pairs, grow the bisection count i and the Newton count j
# until root() agrees with n ** (1 / p) to within 1e-13, and record the counts.
stats = {}
for _ in range(64):
    n = random.randrange(1, 16384)
    p = random.randrange(2, 32)
    x = n ** (1 / p)
    i = j = 2
    b = 0
    while True:
        y = root(n, p, i, j)
        if not abs(y - x) > 1e-13:
            break
        # Alternate between raising i and j; once i reaches 16 only j grows.
        if i >= 16 or b:
            j += 1
        else:
            i += 1
        b = not b

    stats[(n, p)] = (i, j, x, y)

def fast_nth_root(x: float, n: int, lim: int) -> float:
    """Approximate the n-th root of ``x`` through its reciprocal.

    The single-precision bit pattern of ``x`` is turned into an initial guess
    for ``x ** (-1 / n)`` with integer arithmetic, the guess is refined with
    ``lim`` Newton iterations, and the result is inverted.
    """
    (bits,) = struct.unpack('>I', struct.pack('>f', x))
    guess_bits = round(0x3F7A3BEA / n * (n + 1) - bits / n)
    (y,) = struct.unpack('>f', struct.pack('>i', guess_bits))
    step = 0
    while step < lim:
        # Newton step for f(y) = y**(-n) - x
        y = y * (n + 1 - x * y ** n) / n
        step += 1

    return 1 / y


# For 64 random (n, p) pairs, find the smallest Newton iteration count i >= 2
# at which fast_nth_root() agrees with n ** (1 / p) to within 1e-13.
stats1 = {}
for _ in range(64):
    n = random.randrange(1, 16384)
    p = random.randrange(2, 32)
    x = n ** (1 / p)
    i = 2
    y = fast_nth_root(n, p, i)
    while abs(y - x) > 1e-13:
        i += 1
        y = fast_nth_root(n, p, i)

    stats1[(n, p)] = (i, x, y)
In [1434]: stats
Out[1434]:
{(7162, 13): (11, 10, 1.979434344458145, 1.979434344458145),
 (15510, 2): (5, 5, 124.53915047084591, 124.53915047084595),
 (3054, 25): (10, 9, 1.3784618821753079, 1.3784618821753076),
 (1601, 25): (9, 8, 1.3433081611539914, 1.3433081611540099),
 (3522, 21): (10, 9, 1.475348878994169, 1.475348878994169),
 (15107, 14): (12, 11, 1.9884410975573956, 1.9884410975573954),
 (15200, 16): (12, 11, 1.8254301706001883, 1.825430170600188),
 (1900, 15): (9, 8, 1.6541830766301984, 1.6541830766301981),
 (16145, 20): (12, 11, 1.6233116388185762, 1.623311638818576),
 (2580, 4): (7, 6, 7.126969930959522, 7.126969930959522),
 (1702, 27): (11, 10, 1.3172407839773748, 1.3172407839773748),
 (9875, 29): (13, 13, 1.3732280275029944, 1.3732280275029942),
 (15687, 15): (12, 11, 1.9041565885196923, 1.904156588519692),
 (5774, 16): (12, 11, 1.718273525571221, 1.718273525571221),
 (6186, 2): (5, 4, 78.65112840894274, 78.65112840894277),
 (4476, 23): (12, 11, 1.441233509663115, 1.441233509663115),
 (4161, 24): (12, 11, 1.4151416228042228, 1.4151416228042226),
 (16116, 13): (12, 11, 2.1068575543742956, 2.1068575543742956),
 (12380, 14): (11, 11, 1.9603661231094314, 1.9603661231094311),
 (9736, 19): (13, 12, 1.621491836717726, 1.6214918367177258),
 (8612, 26): (13, 12, 1.4169357394205302, 1.4169357394205302),
 (4586, 7): (9, 8, 3.334740217355978, 3.3347402173559777),
 (5232, 24): (12, 12, 1.428711330576587, 1.4287113305765868),
 (14698, 17): (12, 11, 1.7584613929955697, 1.7584613929955695),
 (4931, 13): (10, 9, 1.923410237452901, 1.9234102374529014),
 (7391, 4): (8, 7, 9.272050761175075, 9.272050761175075),
 (9949, 6): (9, 9, 4.637635073009885, 4.637635073009886),
 (4767, 18): (12, 11, 1.6008364077669808, 1.6008364077669806),
 (16318, 8): (11, 10, 3.3618889684623863, 3.3618889684623863),
 (7610, 28): (13, 12, 1.3760077520394016, 1.3760077520394014),
 (13632, 6): (10, 9, 4.887573066390476, 4.887573066390476),
 (8380, 21): (11, 11, 1.5375213098103222, 1.5375213098103224),
 (7247, 14): (11, 10, 1.88679879448582, 1.8867987944858202),
 (11343, 18): (13, 12, 1.6798196720467486, 1.6798196720467486),
 (6468, 17): (11, 10, 1.6755714114675964, 1.6755714114675964),
 (11801, 6): (10, 9, 4.771480299610415, 4.771480299610415),
 (441, 28): (9, 8, 1.2429230307022932, 1.2429230307022932),
 (15341, 14): (12, 11, 1.9906254301097404, 1.9906254301097404),
 (8501, 20): (11, 10, 1.5720758667453518, 1.5720758667453518),
 (2777, 19): (10, 10, 1.5178918732605664, 1.5178918732605664),
 (14842, 30): (14, 13, 1.3773672345540857, 1.3773672345540857),
 (6149, 28): (11, 10, 1.3655715058830975, 1.3655715058830973),
 (13374, 21): (12, 11, 1.5721306482454025, 1.5721306482454025),
 (9947, 30): (13, 13, 1.3591156205784112, 1.359115620578415),
 (14423, 16): (12, 11, 1.8194535606369682, 1.8194535606369682),
 (9341, 31): (13, 12, 1.3430036888404402, 1.3430036888404402),
 (14558, 7): (10, 9, 3.9330441035217714, 3.9330441035217714),
 (152, 16): (8, 7, 1.3688795144738382, 1.368879514473838),
 (13593, 18): (12, 11, 1.6967920812890351, 1.6967920812890351),
 (2834, 7): (8, 8, 3.113149507653915, 3.1131495076539637),
 (11545, 14): (11, 10, 1.950612466604169, 1.9506124666041689),
 (12416, 21): (12, 11, 1.566576147387506, 1.5665761473875057),
 (8998, 6): (9, 8, 4.560624662646501, 4.560624662646504),
 (5245, 27): (11, 10, 1.3733092699013858, 1.3733092699013856),
 (5693, 29): (11, 10, 1.3473937347050562, 1.3473937347050562),
 (3508, 26): (10, 10, 1.3688266297365765, 1.368826629736599),
 (16237, 9): (11, 10, 2.936526854011741, 2.936526854011741),
 (2911, 6): (8, 7, 3.778682197915392, 3.7786821979153924),
 (387, 22): (10, 9, 1.311061987204142, 1.311061987204142),
 (3324, 4): (7, 7, 7.593032412784651, 7.593032412784651),
 (15300, 22): (12, 11, 1.5495773040868939, 1.5495773040868939),
 (5469, 26): (11, 10, 1.3924053694535468, 1.392405369453547),
 (1195, 13): (8, 8, 1.7247279767538015, 1.724727976753898),
 (7998, 13): (11, 10, 1.9963162339549785, 1.9963162339549787)}

In [1435]: stats1
Out[1435]:
{(10882, 3): (4, 22.159990703206965, 22.159990703206965),
 (6673, 28): (6, 1.3695657835909767, 1.3695657835909767),
 (4803, 10): (4, 2.3342709144708604, 2.3342709144708604),
 (1802, 27): (5, 1.3200291160996098, 1.3200291160996294),
 (8380, 15): (4, 1.8262053100463662, 1.826205310046366),
 (12898, 21): (5, 1.569419919895426, 1.5694199198954262),
 (10227, 2): (4, 101.12863096077193, 101.12863096077193),
 (4857, 25): (5, 1.4042832772764529, 1.4042832772765208),
 (1351, 12): (4, 1.8234251715932501, 1.8234251715932501),
 (10180, 16): (3, 1.7802632882832108, 1.7802632882832108),
 (6948, 28): (6, 1.37154252932099, 1.3715425293209902),
 (13901, 10): (4, 2.5959994991170756, 2.5959994991171),
 (7513, 21): (5, 1.529545998990047, 1.529545998990047),
 (7902, 18): (4, 1.6464211857566509, 1.6464211857566777),
 (3277, 31): (6, 1.2983819395882321, 1.2983819395882321),
 (3499, 10): (4, 2.261488555258189, 2.2614885552581887),
 (15234, 30): (6, 1.3785646304878574, 1.3785646304878574),
 (5739, 13): (5, 1.9459928850146935, 1.9459928850146935),
 (1823, 24): (5, 1.3673072329614473, 1.367307232961453),
 (15105, 16): (4, 1.8247150144295399, 1.8247150144295399),
 (16215, 12): (3, 2.2429852247131876, 2.2429852247131876),
 (15844, 20): (5, 1.6217848596088165, 1.6217848596088162),
 (15677, 26): (5, 1.4499608216310222, 1.4499608216310922),
 (11839, 22): (5, 1.5316187797414815, 1.5316187797414818),
 (10163, 26): (6, 1.4259891723870766, 1.4259891723870766),
 (1550, 18): (5, 1.5039751132184096, 1.5039751132184096),
 (15194, 5): (4, 6.8601628355219795, 6.860162835521979),
 (15612, 24): (5, 1.4952969217858556, 1.495296921785857),
 (9469, 12): (4, 2.1446611088057317, 2.144661108805732),
 (4030, 20): (5, 1.5144859625611744, 1.5144859625611742),
 (11729, 3): (4, 22.720627875592783, 22.72062787559279),
 (12709, 29): (6, 1.3852274277113097, 1.3852274277113097),
 (12263, 31): (6, 1.354846884624788, 1.354846884624788),
 (6372, 9): (4, 2.6466548424778766, 2.6466548424779086),
 (7119, 5): (4, 5.8949998528103436, 5.8949998528103436),
 (10737, 27): (6, 1.4102365332846034, 1.4102365332846034),
 (2231, 11): (4, 2.01562182326949, 2.0156218232694902),
 (412, 9): (4, 1.952289125066342, 1.9522891250663446),
 (8417, 5): (4, 6.095810609109851, 6.095810609109851),
 (6759, 31): (6, 1.3290600231725829, 1.3290600231725826),
 (2207, 23): (5, 1.3975994148304356, 1.3975994148304371),
 (4755, 16): (4, 1.6975473914419268, 1.697547391441927),
 (7978, 13): (4, 1.995931787083301, 1.9959317870833602),
 (14957, 19): (4, 1.658550339552885, 1.658550339552935),
 (745, 28): (5, 1.2664178106847905, 1.2664178106847914),
 (2696, 15): (4, 1.6932249497186587, 1.6932249497186587),
 (5484, 7): (4, 3.421029175738748, 3.421029175738748),
 (15410, 10): (4, 2.6228911266357167, 2.6228911266357264),
 (315, 10): (3, 1.7775877772276876, 1.7775877772276876),
 (9252, 13): (4, 2.0188081438649, 2.0188081438649035),
 (1562, 27): (5, 1.313059731029047, 1.3130597310290748),
 (9803, 8): (4, 3.1544225980493534, 3.154422598049354),
 (14443, 16): (4, 1.8196111450479415, 1.8196111450479413),
 (5033, 23): (5, 1.4486017214956801, 1.4486017214956872),
 (16175, 2): (4, 127.18097341976905, 127.18097341976903),
 (15125, 5): (4, 6.853920721113647, 6.853920721113645),
 (16292, 19): (4, 1.6660301757087943, 1.666030175708797),
 (11486, 20): (5, 1.5959101637580406, 1.5959101637580406),
 (13824, 7): (4, 3.9040835527337374, 3.9040835527337383),
 (8604, 3): (4, 20.491172086350442, 20.491172086350442),
 (1225, 3): (4, 10.699874805650794, 10.699874805650795),
 (9163, 21): (5, 1.54407524840236, 1.5440752484023603),
 (7833, 21): (5, 1.53258704058841, 1.53258704058841),
 (8425, 24): (5, 1.4573551932369144, 1.4573551932369178)}

On average, the first method needs roughly 12 iterations of binary search followed by 12 iterations of Newton's method to get the error below 10⁻¹³, while the second method needs 6 iterations of Newton's method to reach the same accuracy.

Are there ways to make the code run faster for the same number of iterations, or ways to speed up the convergence rate of the math involved?


This is not an assignment. It is a self-imposed programming challenge.

Upvotes: 2

Views: 432

Answers (2)

user21508463
user21508463

Reputation:

You don't say what range of values your function should cover, nor the statistical distribution of these values, so we have to suggest a general-purpose method.

The strategy is to first find a good initial approximation for Newton's iterations by working on the order of magnitude.

If you have access to the exponent of the floating-point representation (by hacking the binary representation), a good starting value is √2/2 times 2^(exponent/n). We choose √2/2 rather than 1 to be centered on [1/2, 1). For efficiency, you can precompute these constants for all exponents and all root orders. Anyway, to save space, it can be better to decompose the exponent over n in integer quotient and remainder.

If the exponent is not available, then you can search for it by successive doublings (or halvings), starting from 1 (so 1=2^0, 2=2^1, 4=2^2, 8=2^3...). This amounts to a linear search among the exponents. Yet more efficient is to work with squarings, and implement an exponential search among the exponents (you are doubling the powers each time, 2=2^1, 4=2^2, 16=2^4, 256=2^8...). After you have found a possible range of exponents, revert to linear search. You can also optimize with precomputed values.

Finally, you can start with Newton's iterations. For the case of the square root, you can work with the inverse square root to avoid division. Unfortunately, this does not generalize to higher orders.

Last but not least, it may be beneficial to determine the number of iterations required in the worst case (with the worst initial approximation), and always use this number, rather than test for convergence with a certain tolerance.

Upvotes: 2

Matt Timmermans
Matt Timmermans

Reputation: 59194

Using binary search to get an initial guess for nth_root is not particularly efficient.

You can use a hack on the floating point representation similar to the one in fast_nth..., before proceeding with a proper Newton's method implementation.

Something like:

constexpr int32_t ONE_BITS = std::bit_cast<int32_t>(1.0f);

inline float nth_root(float base, int n) {
    int32_t x_bits = std::bit_cast<int32_t>(base);
    x_bits = (x_bits - ONE_BITS)/n + ONE_BITS;
    // first guess
    float x = std::bit_cast<float>(x_bits);

The important effect of this is just that it divides the floating point exponent by n.

Upvotes: 2

Related Questions