E_1996
E_1996

Reputation: 89

C++ AVX2 custom functions (e.g., "exp") not working on Windows (but work on Linux)

this post is somewhat related to this post I did a couple of days ago:

C++ AVX2 Function Pointers/std::function not working on Windows (but work on Linux)

Now, since then, thanks to the useful comments I resolved that issue; by using:

// Simple test function that just multiplies vector by 2
__m256d  test_simple_AVX2(const __m256d x) {
   
   ALIGN32 const __m256d two = _mm256_set1_pd(2.0);
   ALIGN32 const __m256d res = _mm256_mul_pd(x, two);
   return res; 
   
 }

Rather than:

// Simple test function that just multiplies vector by 2
__m256d  test_simple_AVX2(const __m256d x) {
   
   const __m256d two = _mm256_set1_pd(2.0);
   const __m256d res = _mm256_mul_pd(x, two);
   return res; 
   
 }

Where ALIGN32 is defined as:

#ifdef _MSC_VER
   #define ALIGN32 __declspec(align(32))
#else
   #define ALIGN32 alignas(32)
#endif
 

However, my more complex AVX2 functions - which work fine on Linux - still do not work (even when called directly) and crash - I get an aborted session and am using C++ via Rcpp.

For instance here's my exp function:

// Adapted from: https://stackoverflow.com/questions/48863719/fastest-implementation-of-exponential-function-using-avx
// added   (optional) extra degree(s) for poly approx (oroginal float fn had 4 degrees) - using "minimaxApprox" R package to find coefficient terms
// R code:    minimaxApprox::minimaxApprox(fn = exp, lower = -0.346573590279972643113, upper = 0.346573590279972643113, degree = 5, basis ="Chebyshev")
inline    __m256d fast_exp_1_wo_checks_AVX2 VECTORCALL(const __m256d x)  { 
  
    _mm256_zeroupper();  // Reset AVX state
  
    ALIGN32  __m256d const x_aligned = x;
    
    ALIGN32  __m256d const exp_l2e = _mm256_set1_pd (1.442695040888963387); /* log2(e) */
    ALIGN32  __m256d const exp_l2h = _mm256_set1_pd (-0.693145751999999948367); /* -log(2)_hi */
    ALIGN32  __m256d const exp_l2l = _mm256_set1_pd (-0.00000142860676999999996193); /* -log(2)_lo */
    
    // /* coefficients for core approximation to exp() in [-log(2)/2, log(2)/2] */
    ALIGN32  __m256d const exp_c0 =     _mm256_set1_pd(0.00000276479776161191821278);
    ALIGN32  __m256d const exp_c1 =     _mm256_set1_pd(0.0000248844480527491290235);
    ALIGN32  __m256d const exp_c2 =     _mm256_set1_pd(0.000198411488032534342194);
    ALIGN32  __m256d const exp_c3 =     _mm256_set1_pd(0.00138888017711994078175);
    ALIGN32  __m256d const exp_c4 =     _mm256_set1_pd(0.00833333340524595143906);
    ALIGN32  __m256d const exp_c5 =     _mm256_set1_pd(0.0416666670404215802592);
    ALIGN32  __m256d const exp_c6 =     _mm256_set1_pd(0.166666666664891632843);
    ALIGN32  __m256d const exp_c7 =     _mm256_set1_pd(0.499999999994389376923);
    ALIGN32  __m256d const exp_c8 =     _mm256_set1_pd(1.00000000000001221245);
    ALIGN32  __m256d const exp_c9 =     _mm256_set1_pd(1.00000000000001332268);
    
    ALIGN32  __m256d const input  = x_aligned;
    
    /* exp(x) = 2^i * e^f; i = rint (log2(e) * a), f = a - log(2) * i */
    ALIGN32  __m256d const t = _mm256_mul_pd(x_aligned, exp_l2e);      /* t = log2(e) * a */
    ///  const __m256i i = _mm256_cvttpd_epi32(t);       /* i = (int)rint(t) */
    ALIGN32  __m256i const i = avx2_cvtpd_epi64(t);  
    // const __m256d x_2 = _mm256_round_pd(t, _MM_FROUND_TO_NEAREST_INT) ; // ((0<<4)| _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC|_MM_FROUND_NO_EXC));
    ALIGN32  __m256d const x_2 = _mm256_round_pd(t, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    ALIGN32  __m256d const f0 = _mm256_fmadd_pd(x_2, exp_l2h, input);
    ALIGN32  __m256d const f =  _mm256_fmadd_pd(x_2, exp_l2l, f0);  /* a - log(2)_hi * r */    /* f = a - log(2)_hi * r - log(2)_lo * r */
    
    /* p ~= exp (f), -log(2)/2 <= f <= log(2)/2 */
    ALIGN32 __m256d p = exp_c0;
    p = _mm256_fmadd_pd(p, f, exp_c1);
    p = _mm256_fmadd_pd(p, f, exp_c2);
    p = _mm256_fmadd_pd(p, f, exp_c3);
    p = _mm256_fmadd_pd(p, f, exp_c4);
    p = _mm256_fmadd_pd(p, f, exp_c5);
    p = _mm256_fmadd_pd(p, f, exp_c6);
    p = _mm256_fmadd_pd(p, f, exp_c7);
    p = _mm256_fmadd_pd(p, f, exp_c8);
    p = _mm256_fmadd_pd(p, f, exp_c9);
    
    // ALIGN32  __m256d const res = fast_ldexp_2(p, i);   /* exp(x) = 2^i * p */
    ALIGN32  __m256d const res = fast_ldexp(p, i);   /* exp(x) = 2^i * p */
    _mm256_zeroupper();  // Reset AVX state
    return  res;  
  
} 

Which relies on the following AVX2 functions:


// Helper function for AVX2 64-bit conversion 
inline __m256i avx2_cvtpd_epi64 VECTORCALL(const __m256d x) {
  
     ALIGN32  __m256d const x_aligned = x;
 
     // Extract doubles and convert to int64 one at a tim
     ALIGN32 int64_t result[4];
     ALIGN32 double temp[4];

     _mm256_storeu_pd(temp, x_aligned);
     
     for(int i = 0; i < 4; i++) {
          #ifdef _MSC_VER
                 result[i] = (int64_t)_mm_cvtsd_si64(_mm_load_sd(&temp[i]));  // Use SSE2 conversion
          #else
                 result[i] = (int64_t)std::llrint(temp[i]);     // Use llrint for proper rounding
          #endif
          // result[i] = (int64_t)_mm_cvtsd_si64(_mm_load_sd(&temp[i]));  // Use SSE2 conversion
     }
     
     //// return _mm256_load_si256((__m256i*)result);
     ALIGN32  __m256i const res = _mm256_load_si256(reinterpret_cast<const __m256i*>(static_cast<void*>(result))); 
     return res;
 
}



inline    __m256d   fast_ldexp VECTORCALL(  const __m256d AVX_a,
                                            const __m256i AVX_i) {
  
    _mm256_zeroupper();  // Reset AVX state
  
    ALIGN32  __m256d const AVX_a_aligned = AVX_a;
    ALIGN32  __m256i const AVX_i_aligned = AVX_i;
  
    ALIGN32 const uint64_t shift_val_52 = 52;
    ALIGN32  __m256i const shifted = _mm256_slli_epi64(AVX_i_aligned, shift_val_52);
    ALIGN32  __m256i const a_bits = _mm256_castpd_si256(AVX_a_aligned);
    ALIGN32  __m256i const result = _mm256_add_epi64(shifted, a_bits);
    
    ALIGN32  __m256d const res = _mm256_castsi256_pd(result);
    _mm256_zeroupper();  // Reset AVX state
    return res;
  
}

I was wondering if anybody has any ideas why the simple test AVX2 function works on Windows but why the exp() function only works on Linux? Is it still an alignment issue?

More info: I'm using C++ via Rcpp. compiler: g++ compiler flags: -O3 -march=znver3 -mtune=znver3 -fPIC -D_REENTRANT -DSTAN_THREADS -pthread -fpermissive -mfma -mavx -mavx2 -flarge-source-files

Upvotes: 0

Views: 105

Answers (0)

Related Questions