Stephane
Stephane

Reputation: 419

Fastest way to calculate a digit-sum for a large number (as a decimal string)

I use gmplib to work with big numbers, and for each one I calculate its numeric value, i.e. the repeated sum of its decimal digits (123 -> 6, 74 -> 11 -> 2).

Here is what I did :

/*
 * Returns the digital root (repeated digit sum) of the decimal number
 * stored in the NUL-terminated string in_str: "123" -> 6, "74" -> 2.
 *
 * The result is 0 when the string is empty or consists only of '0'
 * characters, and 1..9 otherwise.  Every character before the terminator
 * is assumed to be a decimal digit; no validation is done.
 */
unsigned short getnumericvalue(const char *in_str)
{
    unsigned long sum = 0;

    /* Test for the terminator before consuming a character, so an empty
       string is never read past its NUL byte. */
    for (const char *ptr = in_str; *ptr != '\0'; ptr++) {
        /* '9' digits are added like any other digit: they do not change
           the sum modulo 9, but they do distinguish "9" from "0". */
        sum += (unsigned long)(*ptr - '0');
    }

    /* 1 + (sum - 1) % 9 maps non-zero multiples of 9 to 9, not 0. */
    return sum == 0 ? 0 : (unsigned short)(1 + (sum - 1) % 9);
}

It works well, but is there a faster way to do it on a Xeon W-3235?

Upvotes: 2

Views: 440

Answers (2)

fuz
fuz

Reputation: 92976

You can use code like the following. The general idea of the algorithm is:

  1. process data bytewise until we reach cacheline alignment
  2. read one cacheline at a time, check for end of string, and add digits to 8 accumulators
  3. reduce 8 accumulators to one and add counts from head
  4. process remainder bytewise

Note that the code below has not been tested.

        // getnumericvalue(ptr)
        //
        // In:  %rdi = pointer to a NUL-terminated string of ASCII digits
        // Out: %eax = digit sum reduced modulo 9, with a remainder of 0
        //             reported as 9
        //
        // The string is summed bytewise until the pointer is 64-byte
        // aligned, then one 64-byte line at a time with AVX-512 (the
        // vptestmb/kortestq/vpsadbw forms used here need AVX512BW), and
        // finally bytewise again for whatever precedes the NUL byte in the
        // last line.  Aligned 64-byte loads never straddle a page boundary.
        .section .text
        .type getnumericvalue, @function
        .globl getnumericvalue
getnumericvalue:
        xor %eax, %eax          // digit counter

        // process string until we reach cache-line alignment
        test $64-1, %dil        // is ptr aligned to 64 byte?
        jz 0f

1:      movzbl (%rdi), %edx     // load a byte from the string
        inc %rdi                // advance pointer
        test %edx, %edx         // is this the NUL byte?
        jz .Lend                // if yes, finish this function
        sub $'0', %edx          // turn ASCII character into digit
        add %edx, %eax          // and add to counter
        test $64-1, %dil        // is ptr aligned to 64 byte?
        jnz 1b                  // if not, process more data

        // process data in cache line increments until the end
        // of the string is found somewhere
0:      vpbroadcastd zero(%rip), %zmm1  // mask of '0' characters
        vpxor %xmm3, %xmm3, %xmm3       // vectorised digit counter

        vmovdqa32 (%rdi), %zmm0         // load one cache line from the string
        vptestmb %zmm0, %zmm0, %k0      // clear k0 bits if any byte is NUL
        kortestq %k0, %k0               // clear CF if a NUL byte is found
        jnc 0f                          // skip loop if a NUL byte is found

        .balign 16
1:      add $64, %rdi                   // advance pointer
        vpsadbw %zmm1, %zmm0, %zmm0     // sum groups of 8 bytes into 8 words
                                        // also subtracts '0' from each byte
        vpaddq %zmm3, %zmm0, %zmm3      // add to counters
        vmovdqa32 (%rdi), %zmm0         // load one cache line from the string
        vptestmb %zmm0, %zmm0, %k0      // clear k0 bits if any byte is NUL
        kortestq %k0, %k0               // clear CF if a NUL byte is found
        jc 1b                           // go on unless a NUL byte was found

        // reduce 8 vectorised counters into rdx
0:      vextracti64x4 $1, %zmm3, %ymm2  // extract high 4 words
        vpaddq %ymm2, %ymm3, %ymm3      // and add them to the low words
        vextracti128 $1, %ymm3, %xmm2   // extract high 2 words
        vpaddq %xmm2, %xmm3, %xmm3      // and add them to the low words
        vpshufd $0x4e, %xmm3, %xmm2     // swap qwords into xmm2
        vpaddq %xmm2, %xmm3, %xmm3      // and add to xmm3
        vmovq %xmm3, %rdx               // move digit counter back to rdx
        add %rdx, %rax                  // and add to counts from scalar head

        // process tail
1:      movzbl (%rdi), %edx     // load a byte from the string
        inc %rdi                // advance pointer
        test %edx, %edx         // is this the NUL byte?
        jz .Lend                // if yes, finish this function
        sub $'0', %edx          // turn ASCII character into digit
        add %rdx, %rax          // and add to counter
        jmp 1b                  // unconditionally go on: the flags here come
                                // from the add, so a conditional jump would
                                // stop early while the running sum is zero

.Lend:  xor %edx, %edx          // zero-extend RAX into RDX:RAX
        mov $9, %ecx            // divide by 9
        div %rcx                // perform division
        mov %edx, %eax          // move remainder to result register
        test %eax, %eax         // is the remainder zero?
        cmovz %ecx, %eax        // if yes, set remainder to 9
        vzeroupper              // restore SSE performance
        ret                     // and return
        .size getnumericvalue, .-getnumericvalue

        // constants
        .section .rodata
        .balign 4
zero:   .byte '0', '0', '0', '0'

Upvotes: 3

chqrlie
chqrlie

Reputation: 144695

Here is a portable solution:

  • it handles the first few digits naively until ptr is properly aligned.
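  • then it loops reading 8 digits at a time, subtracts '0' from each byte and adds the 8 digit values in parallel into the 8 byte lanes of a 64-bit accumulator. Up to 28 such additions can be performed before a lane could overflow (28 × 9 = 252 ≤ 255); the accumulator's bytes are then summed into number.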
  • the test for termination verifies that all digits in the pack have high nibbles equal to 3.
  • the remaining digits are handled one by one.
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

/*
 * Digital root of the decimal string in_str, computed one character at a
 * time while leaving '9' digits out of the running sum (they do not change
 * the sum modulo 9).
 *
 * Returns 0 for an empty string or a string made only of '0' characters,
 * and 1..9 otherwise.  Every character before the terminator is assumed to
 * be a decimal digit; no validation is done.
 */
unsigned getnumericvalue_simple(const char *in_str) {
    unsigned long number = 0;
    int saw_nine = 0;   /* set once any skipped '9' has been seen */

    /* Test for the terminator before consuming a character, so an empty
       string is never read past its NUL byte. */
    for (const char *ptr = in_str; *ptr != '\0'; ptr++) {
        if (*ptr == '9') {
            saw_nine = 1;
        } else {
            number += (unsigned long)(*ptr - '0');
        }
    }

    /* With the nines left out, a zero sum means either "all zeros" (root 0)
       or "only zeros and nines" (root 9); the flag tells them apart. */
    if (number == 0) {
        return saw_nine ? 9 : 0;
    }
    return (unsigned)((number - 1) % 9) + 1;
}

/*
 * Straightforward reference implementation: adds up (character - '0') for
 * every character of the NUL-terminated string, then reduces the total to
 * its digital root.
 *
 * Returns 0 when the total is 0 (including the empty string) and a value in
 * 1..9 otherwise.  The characters are presumably all decimal digits; the
 * function does not check.
 */
unsigned getnumericvalue_naive(const char *ptr) {
    unsigned long total = 0;

    for (; *ptr; ptr++) {
        total += *ptr - '0';
    }

    if (!total) {
        return 0;
    }
    /* Shifting by one keeps non-zero multiples of 9 at 9 instead of 0. */
    return 1 + (total - 1) % 9;
}

/*
 * Digital root of the decimal string at ptr, processing 8 characters per
 * step once the pointer is aligned.
 *
 * Each 8-byte word has 0x30 ('0') subtracted from every byte, leaving one
 * digit value (0..9) per byte lane; the words are added lane-wise into
 * pack1.  A lane holds at most 255, so at most 28 words (28 * 9 = 252) are
 * accumulated before the lanes are folded into number.
 *
 * Returns 0 when the digit sum is 0 (including the empty string) and a
 * value in 1..9 otherwise.  Every character before the terminator is
 * assumed to be a decimal digit; a word containing any byte outside
 * 0x30..0x3F ends the fast loop.
 *
 * NOTE(review): the word holding the terminator is loaded in full, so up to
 * 7 bytes after the NUL are read.  Because the load is aligned it cannot
 * cross a page boundary, but callers wanting strictly in-bounds reads must
 * pad the buffer accordingly.
 */
unsigned getnumericvalue_parallel(const char *ptr) {
    unsigned long long number = 0;
    unsigned long long pack1, pack2;

    /* align source on ull boundary, summing the leading digits one by one */
    while ((uintptr_t)ptr & (sizeof(unsigned long long) - 1)) {
        if (*ptr == '\0')
            return number ? 1 + (number - 1) % 9 : 0;
        number += *ptr++ - '0';
    }

    /* scan 8 bytes at a time */
    for (;;) {
        pack1 = 0;
#define REP8(x) x;x;x;x;x;x;x;x
#define REP28(x) x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x;x
        /* memcpy performs the word load without violating strict aliasing;
           the break leaves the enclosing for loop as soon as a word holds a
           byte whose high nibble is not 3 (the terminator or a non-digit). */
        REP28(memcpy(&pack2, ptr, sizeof pack2);
              pack2 -= 0x3030303030303030ULL;
              if (pack2 & 0xf0f0f0f0f0f0f0f0ULL)
                  break;
              ptr += sizeof(unsigned long long);
              pack1 += pack2);
        /* 28 words accumulated: fold the 8 byte lanes before they overflow */
        REP8(number += pack1 & 0xFF; pack1 >>= 8);
    }
    /* fold the lanes of the partially filled accumulator */
    REP8(number += pack1 & 0xFF; pack1 >>= 8);
#undef REP28
#undef REP8

    /* finish trailing bytes */
    while (*ptr) {
        number += *ptr++ - '0';
    }
    return number ? 1 + (number - 1) % 9 : 0;
}

/*
 * Benchmark driver: builds a string of `digits` decimal digits (default
 * 1000000, or the count given as the first command-line argument), runs the
 * three implementations on it and prints each result with its elapsed CPU
 * time in milliseconds.
 *
 * Returns 0 on success, EXIT_FAILURE on a bad argument or failed allocation.
 */
int main(int argc, char *argv[]) {
    clock_t start;
    unsigned naive_result, simple_result, parallel_result;
    double naive_time, simple_time, parallel_time;
    int digits = 1000000;

    if (argc >= 2) {
        char *end;
        long requested = strtol(argv[1], &end, 0);

        /* Reject non-numeric input, trailing junk, non-positive counts and
           anything that does not fit in an int (this also covers strtol's
           out-of-range result, LONG_MAX). */
        if (end == argv[1] || *end != '\0' || requested < 1 || requested >= INT_MAX) {
            fprintf(stderr, "usage: %s [digit-count >= 1]\n", argv[0]);
            return EXIT_FAILURE;
        }
        digits = (int)requested;
    }

    /* Pad the buffer with zero bytes so that an aligned 8-byte load covering
       the terminator stays inside the allocation; calloc also supplies the
       terminator itself. */
    char *p = calloc((size_t)digits + sizeof(unsigned long long), 1);
    if (p == NULL) {
        perror("calloc");
        return EXIT_FAILURE;
    }
    for (int i = 0; i < digits; i++)
        p[i] = "0123456789123456"[i & 15];
    p[digits] = '\0';

    start = clock();
    simple_result = getnumericvalue_simple(p);
    simple_time = (clock() - start) * 1000.0 / CLOCKS_PER_SEC;

    start = clock();
    naive_result = getnumericvalue_naive(p);
    naive_time = (clock() - start) * 1000.0 / CLOCKS_PER_SEC;

    start = clock();
    parallel_result = getnumericvalue_parallel(p);
    parallel_time = (clock() - start) * 1000.0 / CLOCKS_PER_SEC;

    printf("simple:   %d digits -> %u, %7.3f msec\n", digits, simple_result, simple_time);
    printf("naive:    %d digits -> %u, %7.3f msec\n", digits, naive_result, naive_time);
    printf("parallel: %d digits -> %u, %7.3f msec\n", digits, parallel_result, parallel_time);

    free(p);
    return 0;
}

The timings for 100 million digits:

simple:   100000000 digits -> 3, 100.380 msec
naive:    100000000 digits -> 3,  98.128 msec
parallel: 100000000 digits -> 3,   7.848 msec

Note that the extra test in the posted version is incorrect as getnumericvalue("9") should produce 9, not 0.

The parallel version is 12 times faster than the simple version.

More performance might be obtained using AVX instructions via compiler intrinsics or even assembly language, but for very large arrays, memory bandwidth seems to be the limiting factor.

Upvotes: 2

Related Questions