Khai Nguyen

How to optimize a long series of if/then conditional expressions - SIMD

I'm using SIMD to improve the performance of some C code, but I ran into a function with a long chain of if/then conditions:

if (Di <= -T3) return  -4;
if (Di <= -T2) return  -3;
if (Di <= -T1) return  -2;
if (Di < -NEAR)  return  -1;
if (Di <=  NEAR) return   0;
if (Di < T1)   return   1;
if (Di < T2)   return   2;
if (Di < T3)   return   3;

return  4;

My version using the Intel intrinsics supported by the VC++ compiler actually ran slower.

Is there a better way to optimize this long series of conditional expressions?

Answers (2)

stgatilov

I assume several things:

  1. You deal with int32 data (it can easily be changed to float32, though).
  2. You can pass 4 values to your function at once (not just one). That's what people usually mean by vectorization.
  3. Constants are sorted, i.e. 0 < NEAR < T1 < T2 < T3.
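Assumption 3 is what makes a counting approach possible: with sorted constants, the result of the original if-chain only ever steps upward as Di grows. Here is a minimal sketch that checks this, assuming int32 input. The question never gives the threshold values, so NEAR = 2, T1 = 5, T2 = 10, T3 = 20 are hypothetical placeholders, and `quantize_ref` and `check_non_decreasing` are names made up for this example.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical placeholder thresholds: the question does not give their
   values, so these are chosen only to satisfy 0 < NEAR < T1 < T2 < T3. */
enum { NEAR = 2, T1 = 5, T2 = 10, T3 = 20 };

/* The question's if-chain as a scalar int32 function (the name is ours). */
static int32_t quantize_ref(int32_t Di) {
    if (Di <= -T3)   return -4;
    if (Di <= -T2)   return -3;
    if (Di <= -T1)   return -2;
    if (Di <  -NEAR) return -1;
    if (Di <=  NEAR) return  0;
    if (Di <   T1)   return  1;
    if (Di <   T2)   return  2;
    if (Di <   T3)   return  3;
    return 4;
}

/* Because the thresholds are sorted, the result never decreases as Di
   grows: it behaves like a count of the boundaries Di has passed. */
static void check_non_decreasing(int32_t lo, int32_t hi) {
    for (int32_t d = lo; d < hi; ++d)
        assert(quantize_ref(d) <= quantize_ref(d + 1));
}
```

For example, with these placeholder values quantize_ref(-3) is -1 and quantize_ref(3) is 1.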

Here is a vectorized function:

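#include <emmintrin.h>  // SSE2 intrinsics

// NEAR, T1, T2, T3 are the int32 constants from the question, assumed sorted.
// Returns, in each of the 4 lanes, the value the scalar if-chain would return.
__m128i func4(__m128i D) {
  // Lane = -1 (all bits set) where D > threshold, else 0.
  __m128i cmp_m3 = _mm_cmpgt_epi32(D, _mm_set1_epi32(-T3));
  __m128i cmp_m2 = _mm_cmpgt_epi32(D, _mm_set1_epi32(-T2));
  __m128i cmp_m1 = _mm_cmpgt_epi32(D, _mm_set1_epi32(-T1));
  __m128i cmp_p0 = _mm_cmpgt_epi32(D, _mm_set1_epi32(NEAR));
  // Lane = -(number of "greater than" tests that passed).
  __m128i reduce_true = _mm_add_epi32(_mm_add_epi32(cmp_m3, cmp_m2), _mm_add_epi32(cmp_m1, cmp_p0));
  // Lane = -1 where D < threshold, else 0.
  __m128i cmp_m0 = _mm_cmplt_epi32(D, _mm_set1_epi32(-NEAR));
  __m128i cmp_p1 = _mm_cmplt_epi32(D, _mm_set1_epi32(T1));
  __m128i cmp_p2 = _mm_cmplt_epi32(D, _mm_set1_epi32(T2));
  __m128i cmp_p3 = _mm_cmplt_epi32(D, _mm_set1_epi32(T3));
  // Lane = -(number of "less than" tests that passed).
  __m128i reduce_false = _mm_add_epi32(_mm_add_epi32(cmp_p3, cmp_p2), _mm_add_epi32(cmp_p1, cmp_m0));
  // (-false_count) - (-true_count) = true_count - false_count
  return _mm_sub_epi32(reduce_false, reduce_true);
}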
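Assumption 2 means the caller has to feed func4 four values at a time. A minimal sketch of a driver over an int32 array might look like the following; it assumes an x86/x64 target with SSE2 (always available on x64), uses the standard unaligned load/store intrinsics `_mm_loadu_si128`/`_mm_storeu_si128` so the array need not be 16-byte aligned, and falls back to the scalar chain for the last n % 4 elements. The threshold values and the names `quantize_ref` and `quantize_array` are hypothetical.

```c
#include <assert.h>
#include <emmintrin.h>  /* SSE2 intrinsics; x86/x64 only */
#include <stddef.h>
#include <stdint.h>

/* Hypothetical placeholder thresholds satisfying 0 < NEAR < T1 < T2 < T3. */
enum { NEAR = 2, T1 = 5, T2 = 10, T3 = 20 };

/* Scalar if-chain from the question, used for the leftover elements. */
static int32_t quantize_ref(int32_t Di) {
    if (Di <= -T3)   return -4;
    if (Di <= -T2)   return -3;
    if (Di <= -T1)   return -2;
    if (Di <  -NEAR) return -1;
    if (Di <=  NEAR) return  0;
    if (Di <   T1)   return  1;
    if (Di <   T2)   return  2;
    if (Di <   T3)   return  3;
    return 4;
}

/* func4 from the answer, unchanged apart from being static. */
static __m128i func4(__m128i D) {
    __m128i cmp_m3 = _mm_cmpgt_epi32(D, _mm_set1_epi32(-T3));
    __m128i cmp_m2 = _mm_cmpgt_epi32(D, _mm_set1_epi32(-T2));
    __m128i cmp_m1 = _mm_cmpgt_epi32(D, _mm_set1_epi32(-T1));
    __m128i cmp_p0 = _mm_cmpgt_epi32(D, _mm_set1_epi32(NEAR));
    __m128i reduce_true = _mm_add_epi32(_mm_add_epi32(cmp_m3, cmp_m2), _mm_add_epi32(cmp_m1, cmp_p0));
    __m128i cmp_m0 = _mm_cmplt_epi32(D, _mm_set1_epi32(-NEAR));
    __m128i cmp_p1 = _mm_cmplt_epi32(D, _mm_set1_epi32(T1));
    __m128i cmp_p2 = _mm_cmplt_epi32(D, _mm_set1_epi32(T2));
    __m128i cmp_p3 = _mm_cmplt_epi32(D, _mm_set1_epi32(T3));
    __m128i reduce_false = _mm_add_epi32(_mm_add_epi32(cmp_p3, cmp_p2), _mm_add_epi32(cmp_p1, cmp_m0));
    return _mm_sub_epi32(reduce_false, reduce_true);
}

/* Hypothetical driver: 4 values per iteration, then a scalar tail. */
static void quantize_array(const int32_t *in, int32_t *out, size_t n) {
    size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        __m128i d = _mm_loadu_si128((const __m128i *)(in + i));
        _mm_storeu_si128((__m128i *)(out + i), func4(d));
    }
    for (; i < n; ++i)
        out[i] = quantize_ref(in[i]);
}
```

If your buffers are known to be 16-byte aligned, the aligned `_mm_load_si128`/`_mm_store_si128` variants could be used instead.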

On random input data, it runs about 11 times faster than your original version on Ivy Bridge with MSVC 2013 x64:

Time = 4.436   (-39910000)
Time = 0.409   (-39910000)

The full code with testing is available here.

The idea is simple. A non-vectorized version of the proposed solution is the function funcX in the code at the link above; it may explain the idea better than words.
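funcX itself is not reproduced here, so the following is only a hypothetical scalar sketch of the same counting idea, not the author's code. It relies on the fact that a C comparison evaluates to 0 or 1, so each sum directly counts how many tests passed. The threshold values and the names `quantize_ref`, `quantize_counts` and `check_counts` are assumptions made for this example.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical placeholder thresholds satisfying 0 < NEAR < T1 < T2 < T3. */
enum { NEAR = 2, T1 = 5, T2 = 10, T3 = 20 };

/* Reference: the question's if-chain. */
static int32_t quantize_ref(int32_t Di) {
    if (Di <= -T3)   return -4;
    if (Di <= -T2)   return -3;
    if (Di <= -T1)   return -2;
    if (Di <  -NEAR) return -1;
    if (Di <=  NEAR) return  0;
    if (Di <   T1)   return  1;
    if (Di <   T2)   return  2;
    if (Di <   T3)   return  3;
    return 4;
}

/* Branch-free scalar form of func4: the same 8 comparisons, counted
   directly as 0/1 values instead of as 0/-1 lane masks. */
static int32_t quantize_counts(int32_t D) {
    int32_t count_true  = (D > -T3) + (D > -T2) + (D > -T1) + (D > NEAR);
    int32_t count_false = (D < -NEAR) + (D < T1) + (D < T2) + (D < T3);
    return count_true - count_false;
}

/* Compare against the if-chain on every integer in [lo, hi]. */
static void check_counts(int32_t lo, int32_t hi) {
    for (int32_t d = lo; d <= hi; ++d)
        assert(quantize_counts(d) == quantize_ref(d));
}
```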

We take a register D containing 4 packed values as input and compare it against all 8 of your constants using the _mm_cmp* intrinsics. These comparisons produce 8 bitmasks, cmp_pX and cmp_mX. Within each 32-bit lane of a bitmask, the bits are either all zero or all one: all 32 bits are 0 where the comparison was false, and all 32 bits are 1 where it was true.
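To make the lane semantics concrete without depending on an x86 compiler, here is a small scalar model of a single 32-bit lane. The function names are hypothetical; they only mimic what `_mm_cmpgt_epi32` writes into one lane. The second helper reinterprets those bits as a signed int32 via `memcpy`, which is well defined, and since `int32_t` is guaranteed to be two's complement, all ones reads as -1.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Scalar model of one 32-bit lane of _mm_cmpgt_epi32 (the name is ours):
   every bit is set when a > b, every bit is clear otherwise. */
static uint32_t lane_cmpgt_bits(int32_t a, int32_t b) {
    return a > b ? 0xFFFFFFFFu : 0u;
}

/* Reinterpret the 32 mask bits as a signed int32, the view that the
   later integer additions rely on. */
static int32_t lane_bits_as_int32(uint32_t bits) {
    int32_t value;
    memcpy(&value, &bits, sizeof value);
    return value;
}
```

So a "true" lane contributes -1 to an integer sum and a "false" lane contributes 0.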

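Now recall that a 32-bit integer with all bits set is -1 in two's complement. So when we add four comparison results together, each lane holds the number of true comparisons, negated. Finally, we subtract one sum from the other: reduce_false - reduce_true = (number of "greater than" tests passed) - (number of "less than" tests passed), which is exactly the desired result.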
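The arithmetic can be traced lane by lane with a plain-C model of func4. This is a sketch under stated assumptions: `vec4` is a hypothetical stand-in for an `__m128i` holding four int32 lanes, the `v4_*` helpers only model the corresponding SSE2 intrinsics, and the threshold values are placeholders.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical placeholder thresholds satisfying 0 < NEAR < T1 < T2 < T3. */
enum { NEAR = 2, T1 = 5, T2 = 10, T3 = 20 };

/* Plain-C stand-in for an __m128i holding four int32 lanes. */
typedef struct { int32_t v[4]; } vec4;

static vec4 v4_set1(int32_t x) {
    vec4 r = {{x, x, x, x}};
    return r;
}

/* Models _mm_cmpgt_epi32: lane = -1 (all bits set) if a > b, else 0. */
static vec4 v4_cmpgt(vec4 a, vec4 b) {
    vec4 r;
    for (int i = 0; i < 4; ++i) r.v[i] = a.v[i] > b.v[i] ? -1 : 0;
    return r;
}

/* Models _mm_cmplt_epi32: a < b is the same test as b > a. */
static vec4 v4_cmplt(vec4 a, vec4 b) { return v4_cmpgt(b, a); }

/* Model _mm_add_epi32 and _mm_sub_epi32 (only used on small mask sums here). */
static vec4 v4_add(vec4 a, vec4 b) {
    vec4 r;
    for (int i = 0; i < 4; ++i) r.v[i] = a.v[i] + b.v[i];
    return r;
}
static vec4 v4_sub(vec4 a, vec4 b) {
    vec4 r;
    for (int i = 0; i < 4; ++i) r.v[i] = a.v[i] - b.v[i];
    return r;
}

/* Same dataflow as func4. */
static vec4 func4_model(vec4 D) {
    /* Each lane: -(number of "greater than" tests that passed). */
    vec4 reduce_true = v4_add(v4_add(v4_cmpgt(D, v4_set1(-T3)), v4_cmpgt(D, v4_set1(-T2))),
                              v4_add(v4_cmpgt(D, v4_set1(-T1)), v4_cmpgt(D, v4_set1(NEAR))));
    /* Each lane: -(number of "less than" tests that passed). */
    vec4 reduce_false = v4_add(v4_add(v4_cmplt(D, v4_set1(-NEAR)), v4_cmplt(D, v4_set1(T1))),
                               v4_add(v4_cmplt(D, v4_set1(T2)), v4_cmplt(D, v4_set1(T3))));
    /* (-false_count) - (-true_count) = true_count - false_count */
    return v4_sub(reduce_false, reduce_true);
}
```

For input lanes {-25, -3, 0, 12} this yields {-4, -1, 0, 3}. In the lane holding 12, all four "greater than" tests pass (reduce_true = -4) and only 12 < T3 holds among the "less than" tests (reduce_false = -1), giving -1 - (-4) = 3.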

P.S. Here is the assembly code generated for the inner loop:

movdqa  xmm3, XMMWORD PTR [rcx]
movdqa  xmm4, xmm10
movdqa  xmm0, xmm9
add rcx, 16
pcmpgtd xmm0, xmm3
pcmpgtd xmm4, xmm3
paddd   xmm4, xmm0
movdqa  xmm2, xmm3
movdqa  xmm1, xmm8
pcmpgtd xmm1, xmm3
pcmpgtd xmm2, xmm14
movdqa  xmm0, xmm7
pcmpgtd xmm0, xmm3
paddd   xmm1, xmm0
paddd   xmm4, xmm1
movdqa  xmm0, xmm3
movdqa  xmm1, xmm3
pcmpgtd xmm1, xmm12
pcmpgtd xmm0, xmm13
pcmpgtd xmm3, xmm11
paddd   xmm1, xmm3
paddd   xmm2, xmm0
paddd   xmm2, xmm1
psubd   xmm4, xmm2
paddd   xmm4, xmm5
movdqa  xmm5, xmm4
cmp rcx, r15
jl  SHORT $LL3@main

tsapelman

You can try to get rid of the conditions altogether and measure the time again. Your code

if (Di <= -T3) return  -4;
if (Di <= -T2) return  -3;
if (Di <= -T1) return  -2;
if (Di < -NEAR)  return  -1;
if (Di <=  NEAR) return   0;
if (Di < T1)   return   1;
if (Di < T2)   return   2;
if (Di < T3)   return   3;

return  4;

can be transformed into an unconditional form:

return
    (Di <= -T3)*(-4) + (Di > -T3) * (
    (Di <= -T2)*(-3) + (Di > -T2) * (
    (Di <= -T1)*(-2) + (Di > -T1) * (
    (Di < -NEAR)*(-1) + (Di >= -NEAR) * (
    (Di <=  NEAR)*0 + (Di > NEAR) * (
    (Di < T1)*1 + (Di >= T1) * (
    (Di < T2)*2 + (Di >= T2) * (
    (Di < T3)*3 + (Di >= T3) * (
    4
    ))))))));

You can probably optimize this code further if you know something about the possible values of your variables.
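For example, if the thresholds are known to be sorted as 0 <= NEAR < T1 < T2 < T3 (an assumption, not something stated in the question), the nested products collapse into a plain sum: start at -4 and add 1 for every boundary Di has reached. The following is a hypothetical sketch of that simplification with placeholder threshold values, checked against the original if-chain; `quantize_sum` and the other names are made up for this example.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical placeholder thresholds; the rewrite below requires
   0 <= NEAR < T1 < T2 < T3. */
enum { NEAR = 2, T1 = 5, T2 = 10, T3 = 20 };

/* Reference: the question's if-chain. */
static int32_t quantize_ref(int32_t Di) {
    if (Di <= -T3)   return -4;
    if (Di <= -T2)   return -3;
    if (Di <= -T1)   return -2;
    if (Di <  -NEAR) return -1;
    if (Di <=  NEAR) return  0;
    if (Di <   T1)   return  1;
    if (Di <   T2)   return  2;
    if (Di <   T3)   return  3;
    return 4;
}

/* Each comparison is 0 or 1 in C; with sorted thresholds, the number of
   boundaries passed, offset by -4, is the answer. */
static int32_t quantize_sum(int32_t Di) {
    return -4 + (Di > -T3) + (Di > -T2) + (Di > -T1) + (Di >= -NEAR)
              + (Di > NEAR) + (Di >= T1) + (Di >= T2) + (Di >= T3);
}

/* Compare against the if-chain on every integer in [lo, hi]. */
static void check_sum(int32_t lo, int32_t hi) {
    for (int32_t d = lo; d <= hi; ++d)
        assert(quantize_sum(d) == quantize_ref(d));
}
```

Whether a compiler actually emits branch-free instructions for either form depends on the compiler and its flags, so it is worth timing both.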
