Łukasz Lew

What is the fastest way to compute large dot products?

Consider this snippet:

double dot(double* a, double* b, int n) {
  double sum = 0;
  for (int i = 0; i < n; ++i) sum += a[i] * b[i];
  return sum;
}

How can I speed it up using intrinsics or assembler?

Answers (1)

Paul R

Here is a simple SSE implementation:

#include <pmmintrin.h> // SSE3 intrinsics, including _mm_hadd_pd

__m128d vsum = _mm_set1_pd(0.0);
double sum = 0.0;
int k;

// process 2 elements per iteration
for (k = 0; k < n - 1; k += 2)
{
    __m128d va = _mm_loadu_pd(&a[k]);
    __m128d vb = _mm_loadu_pd(&b[k]);
    __m128d vs = _mm_mul_pd(va, vb);
    vsum = _mm_add_pd(vsum, vs);
}

// horizontal sum of 2 partial dot products
vsum = _mm_hadd_pd(vsum, vsum);
_mm_store_sd(&sum, vsum);

// clean up any remaining elements
for ( ; k < n; ++k)
{
    sum += a[k] * b[k];
}
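
For reference, here is a minimal self-contained sketch that wraps the SSE loop above into a function. The names dot_sse and dot_scalar are chosen here for illustration only, and __attribute__((target("sse3"))) is a GCC/Clang extension (an assumption about your toolchain) that enables SSE3 for just this function; compiling the whole file with -msse3 works as well.

```c
#include <assert.h>
#include <pmmintrin.h> /* SSE3: _mm_hadd_pd */

/* Scalar reference: the loop from the question. */
double dot_scalar(const double *a, const double *b, int n)
{
    double sum = 0.0;
    for (int i = 0; i < n; ++i)
        sum += a[i] * b[i];
    return sum;
}

/* The SSE loop from the answer, wrapped in a function (illustrative name).
   target("sse3") is a GCC/Clang extension; -msse3 has the same effect. */
__attribute__((target("sse3")))
double dot_sse(const double *a, const double *b, int n)
{
    assert(n >= 0);
    __m128d vsum = _mm_setzero_pd();
    double sum = 0.0;
    int k;

    /* 2 doubles per 128-bit register; unaligned loads, so any pointer works */
    for (k = 0; k < n - 1; k += 2)
    {
        __m128d va = _mm_loadu_pd(&a[k]);
        __m128d vb = _mm_loadu_pd(&b[k]);
        vsum = _mm_add_pd(vsum, _mm_mul_pd(va, vb));
    }

    /* both lanes now hold the sum of the 2 partial dot products */
    vsum = _mm_hadd_pd(vsum, vsum);
    _mm_store_sd(&sum, vsum);

    /* scalar tail when n is odd */
    for ( ; k < n; ++k)
        sum += a[k] * b[k];
    return sum;
}
```

Because the vector code adds the products in a different order from the scalar loop, results on general floating-point data can differ from dot() in the last few bits; exact equality checks against the scalar version are only safe when every product and partial sum is exactly representable, for example with small integer values.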

Note that if you can guarantee that a and b are 16-byte aligned, you can use _mm_load_pd rather than _mm_loadu_pd, which may help performance, particularly on older (pre-Nehalem) CPUs.
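
A sketch of the aligned variant of the loop above, assuming C11 aligned_alloc is available (it is not provided by MSVC, for example). dot_sse_aligned and alloc_aligned_doubles are hypothetical names used only for this illustration, and the target attribute is the same GCC/Clang extension as -msse3. Since each iteration advances by 2 doubles (16 bytes), every &a[k] stays aligned as long as a itself is.

```c
#include <assert.h>
#include <stdint.h>    /* uintptr_t */
#include <stdlib.h>    /* aligned_alloc, free (C11) */
#include <pmmintrin.h> /* SSE3 */

/* Hypothetical helper: allocate room for n doubles on a 16-byte boundary.
   C11 aligned_alloc expects the size to be a multiple of the alignment,
   so round up to whole pairs of doubles (16 bytes each). */
double *alloc_aligned_doubles(int n)
{
    assert(n >= 0);
    size_t pairs = ((size_t)n + 1) / 2;
    if (pairs == 0)
        pairs = 1; /* avoid a zero-byte request */
    return aligned_alloc(16, pairs * 2 * sizeof(double));
}

/* Same loop as the unaligned version, but with _mm_load_pd.
   Precondition: a and b are 16-byte aligned. */
__attribute__((target("sse3")))
double dot_sse_aligned(const double *a, const double *b, int n)
{
    assert((uintptr_t)a % 16 == 0 && (uintptr_t)b % 16 == 0);
    __m128d vsum = _mm_setzero_pd();
    double sum = 0.0;
    int k;

    for (k = 0; k < n - 1; k += 2)
    {
        __m128d va = _mm_load_pd(&a[k]); /* aligned load */
        __m128d vb = _mm_load_pd(&b[k]);
        vsum = _mm_add_pd(vsum, _mm_mul_pd(va, vb));
    }

    vsum = _mm_hadd_pd(vsum, vsum);
    _mm_store_sd(&sum, vsum);

    for ( ; k < n; ++k)
        sum += a[k] * b[k];
    return sum;
}
```

An alternative to aligned_alloc on x86 compilers is _mm_malloc/_mm_free; either way, memory obtained this way must be released with the matching deallocation function.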

Note also that for loops such as this, where there are very few arithmetic instructions relative to the number of loads (one multiply and one add for every two loads), performance may well be limited by memory bandwidth, and the expected speed-up from vectorization may not be realised in practice.
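
To put rough numbers on the bandwidth point: each element needs two 8-byte doubles from memory for two floating-point operations, so the arithmetic intensity is fixed by the loop itself. The bandwidth figure B_mem = 20 GB/s below is an assumed, illustrative value, not a measurement; substitute your own machine's figure.

```latex
I = \frac{\text{flops per element}}{\text{bytes loaded per element}}
  = \frac{1\,\text{mul} + 1\,\text{add}}{2 \times 8\,\text{B}}
  = \frac{2}{16} = 0.125\ \text{flop/B}

P_{\max} \le I \cdot B_{\text{mem}}
  = 0.125\ \text{flop/B} \times 20\ \text{GB/s}
  = 2.5\ \text{GFLOP/s}
  \;\approx\; 1.25 \times 10^{9}\ \text{elements/s}
```

If the scalar loop already streams data at close to that rate, wider vectors cannot make it faster; when both arrays fit in cache, the available bandwidth is much higher and the SIMD version has more room to help.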


If you want to target CPUs with AVX, it's a fairly straightforward conversion from the SSE implementation above to process 4 elements per iteration rather than 2:

#include <immintrin.h> // AVX intrinsics

__m256d vsum = _mm256_set1_pd(0.0);
double sum = 0.0;
int k;

// process 4 elements per iteration
for (k = 0; k < n - 3; k += 4)
{
    __m256d va = _mm256_loadu_pd(&a[k]);
    __m256d vb = _mm256_loadu_pd(&b[k]);
    __m256d vs = _mm256_mul_pd(va, vb);
    vsum = _mm256_add_pd(vsum, vs);
}

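// horizontal sum of 4 partial dot products {v0, v1, v2, v3}:
// after the first hadd every lane holds v0+v1 or v2+v3,
// after the second every lane holds v0+v1+v2+v3
vsum = _mm256_hadd_pd(_mm256_permute2f128_pd(vsum, vsum, 0x20), _mm256_permute2f128_pd(vsum, vsum, 0x31));
vsum = _mm256_hadd_pd(_mm256_permute2f128_pd(vsum, vsum, 0x20), _mm256_permute2f128_pd(vsum, vsum, 0x31));
// store the lowest lane (there is no 256-bit store_sd, so go through the low 128-bit half)
_mm_store_sd(&sum, _mm256_castpd256_pd128(vsum));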

// clean up any remaining elements
for ( ; k < n; ++k)
{
    sum += a[k] * b[k];
}
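
A self-contained sketch of the AVX version as a function; dot_avx is a name chosen here for illustration. For the final reduction it swaps in a different, commonly used sequence instead of the two permute/hadd rounds: extract the upper 128-bit half with _mm256_extractf128_pd, add it to the lower half, then finish with _mm_hadd_pd and read the result with _mm_cvtsd_f64. As before, __attribute__((target("avx"))) is a GCC/Clang extension equivalent to compiling with -mavx, and the function must only be called on a CPU that supports AVX.

```c
#include <assert.h>
#include <immintrin.h> /* AVX */

/* Illustrative AVX dot product; caller must ensure the CPU supports AVX. */
__attribute__((target("avx")))
double dot_avx(const double *a, const double *b, int n)
{
    assert(n >= 0);
    __m256d vsum = _mm256_setzero_pd();
    int k;

    /* 4 doubles per 256-bit register */
    for (k = 0; k < n - 3; k += 4)
    {
        __m256d va = _mm256_loadu_pd(&a[k]);
        __m256d vb = _mm256_loadu_pd(&b[k]);
        vsum = _mm256_add_pd(vsum, _mm256_mul_pd(va, vb));
    }

    /* Reduce 4 lanes to 1: add the high 128-bit half onto the low half,
       then add the 2 remaining lanes together. */
    __m128d lo = _mm256_castpd256_pd128(vsum);
    __m128d hi = _mm256_extractf128_pd(vsum, 1);
    __m128d pair = _mm_add_pd(lo, hi);
    pair = _mm_hadd_pd(pair, pair);
    double sum = _mm_cvtsd_f64(pair);

    /* scalar tail for the last n % 4 elements */
    for ( ; k < n; ++k)
        sum += a[k] * b[k];
    return sum;
}
```

With GCC or Clang, __builtin_cpu_supports("avx") can be used at run time to choose between an AVX path and an SSE fallback.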
