Aequitas

Reputation: 2265

How can I make branchless code and how does it work?

Related to this answer.

The linked answer mentions that you can avoid branch mispredictions by avoiding branches altogether.

The user demonstrates this by replacing:

if (data[c] >= 128)
{
    sum += data[c];
}

With:

int t = (data[c] - 128) >> 31;
sum += ~t & data[c];

How are these two equivalent (for the specific data set, not strictly equivalent)?

What are some general ways I can do similar things in similar situations? Would it always be by using >> and ~?

Upvotes: 38

Views: 21408

Answers (3)

Aki Suihkonen

Reputation: 20027

Branchless code typically means evaluating all possible outcomes E_i of a conditional statement, each with a weight w_i from the set {0, 1}, such that Sum{ w_i } = 1. Most of the calculations are essentially discarded. Some optimization is possible because E_i doesn't have to be correct when the corresponding weight w_i (or mask m_i) is zero.

  result = (w_0 * E_0) + (w_1 * E_1) + ... + (w_n * E_n)    ;; or
  result = (m_0 & E_0) | (m_1 & E_1) | ... | (m_n & E_n)

where m_i stands for a bitmask.
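For the two-outcome case, the mask form can be sketched roughly as follows. This is only an illustration in Java (where >> on an int is an arithmetic shift); the class and method names are made up, and it assumes that x - threshold does not overflow.

```java
class MaskSelect {
    // Returns ifTrue when x >= threshold, otherwise ifFalse, without a branch.
    // Assumption: x - threshold does not overflow the int range.
    static int select(int x, int threshold, int ifTrue, int ifFalse) {
        int m = ~((x - threshold) >> 31);      // -1 (all ones) if x >= threshold, else 0
        return (m & ifTrue) | (~m & ifFalse);  // result = (m_0 & E_0) | (m_1 & E_1)
    }

    public static void main(String[] args) {
        System.out.println(select(200, 128, 200, 0)); // prints 200
        System.out.println(select(100, 128, 100, 0)); // prints 0
    }
}
```

Both ifTrue and ifFalse are always computed by the caller; the mask only decides which one survives.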

Speed can also be gained by evaluating the E_i in parallel and then collapsing the results horizontally.

This contrasts with the semantics of if (a) b; else c; or its ternary shorthand a ? b : c, where only one expression out of [b, c] is evaluated.

Thus the ternary operator is no magic bullet for branchless code. A decent compiler produces branchless code just as readily for

t = data[n];
if (t >= 128) sum+=t;

as for the ternary form, emitting for example:

    movl    -4(%rdi,%rdx), %ecx
    leal    (%rax,%rcx), %esi
    addl    $-128, %ecx
    cmovge  %esi, %eax

Variations of branchless code include expressing the problem through other branchless non-linear functions, such as ABS, if available on the target machine.

e.g.

 2 * min(a,b) = a + b - ABS(a - b),
 2 * max(a,b) = a + b + ABS(a - b)

or even:

 ABS(x) = sqrt(x*x)      ;; caveat -- this is "probably" not efficient
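A small sketch of the two ABS identities in Java. The class name is hypothetical, and widening to long is an assumption added here so that a - b and a + b cannot overflow. Whether Math.abs itself compiles to branch-free code depends on the JIT and the target machine.

```java
class AbsMinMax {
    // 2 * min(a, b) = a + b - ABS(a - b)
    static int min(int a, int b) {
        long d = Math.abs((long) a - b);        // widened so the difference cannot overflow
        return (int) (((long) a + b - d) / 2);  // numerator is exactly 2 * min, so / 2 is exact
    }

    // 2 * max(a, b) = a + b + ABS(a - b)
    static int max(int a, int b) {
        long d = Math.abs((long) a - b);
        return (int) (((long) a + b + d) / 2);  // numerator is exactly 2 * max
    }

    public static void main(String[] args) {
        System.out.println(min(3, -7)); // prints -7
        System.out.println(max(3, -7)); // prints 3
    }
}
```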

In addition to >> and ~, it may be equally beneficial to use bool and !bool instead of (int >> 31), whose result may be undefined or implementation-defined in some languages. Likewise, if the condition evaluates to a value in [0, 1], one can generate a working mask by negation:

-[0, 1] = [0, 0xffffffff]  in 2's complement representation
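The negation trick might look like this in Java. Java has no implicit boolean-to-int conversion, so this sketch derives the 0/1 value from the sign bit with the unsigned shift >>> instead of from a bool; the names are illustrative only, and x - threshold is assumed not to overflow.

```java
class NegateMask {
    // 1 if x < threshold, else 0. Assumption: x - threshold does not overflow.
    static int lessThanBit(int x, int threshold) {
        return (x - threshold) >>> 31;  // unsigned shift leaves just the sign bit
    }

    // Negating a 0/1 value gives 0 or 0xffffffff (-1) in two's complement.
    static int maskFromBit(int bit) {
        return -bit;
    }

    // Adds x to sum only when x >= threshold, without a branch.
    static int addIfAtLeast(int sum, int x, int threshold) {
        int keep = ~maskFromBit(lessThanBit(x, threshold)); // all ones if x >= threshold
        return sum + (keep & x);
    }
}
```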

Upvotes: 11

Louis Wasserman

Reputation: 198113

int t = (data[c] - 128) >> 31;

The trick here is that if data[c] >= 128, then data[c] - 128 is nonnegative, otherwise it is negative. The highest bit in an int, the sign bit, is 1 if and only if that number is negative. >> is a shift that extends the sign bit, so shifting right by 31 makes the whole result 0 if it used to be nonnegative, and all 1 bits (which represents -1) if it used to be negative. So t is 0 if data[c] >= 128, and -1 otherwise. ~t switches these possibilities, so ~t is -1 if data[c] >= 128, and 0 otherwise.

x & (-1) is always equal to x, and x & 0 is always equal to 0. So sum += ~t & data[c] increases sum by 0 if data[c] < 128, and by data[c] otherwise.
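To see the equivalence on concrete values, here is a small Java sketch that runs both forms over the same array. The class name and sample data are invented for illustration; it assumes, as in the question, that the values lie in 0..255, so data[c] - 128 cannot overflow.

```java
class BranchlessSum {
    // Original form with a conditional branch.
    static long sumBranchy(int[] data) {
        long sum = 0;
        for (int c = 0; c < data.length; c++) {
            if (data[c] >= 128) {
                sum += data[c];
            }
        }
        return sum;
    }

    // Bit-twiddling form from the question.
    static long sumBranchless(int[] data) {
        long sum = 0;
        for (int c = 0; c < data.length; c++) {
            int t = (data[c] - 128) >> 31;  // 0 if data[c] >= 128, else -1
            sum += ~t & data[c];            // adds data[c] or 0
        }
        return sum;
    }

    public static void main(String[] args) {
        int[] data = {0, 127, 128, 255, 64, 200};
        System.out.println(sumBranchy(data));    // prints 583
        System.out.println(sumBranchless(data)); // prints 583
    }
}
```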

Many of these tricks can be applied elsewhere. This trick can certainly be generally applied to have a number be 0 if and only if one value is greater than or equal to another value, and -1 otherwise, and you can mess with it some more to get <=, <, and so on. Bit twiddling like this is a common approach to making mathematical operations branch-free, though it's certainly not always going to be built out of the same operations; ^ (xor) and | (or) also come into play sometimes.
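As a rough illustration of those variations, the following Java sketch builds masks for the other comparisons and shows two xor-based uses. The helper names are made up, and every mask assumes the subtraction involved does not overflow.

```java
class MaskTricks {
    // Each mask is -1 (all ones) when the comparison holds and 0 otherwise.
    // Assumption: a - b and b - a do not overflow.
    static int ltMask(int a, int b) { return (a - b) >> 31; }   // a <  b
    static int geMask(int a, int b) { return ~ltMask(a, b); }   // a >= b
    static int gtMask(int a, int b) { return ltMask(b, a); }    // a >  b
    static int leMask(int a, int b) { return ~ltMask(b, a); }   // a <= b

    // xor-based absolute value: flips the bits and adds 1 only when x is negative.
    static int abs(int x) {
        int m = x >> 31;     // -1 if x is negative, else 0
        return (x ^ m) - m;
    }

    // xor-based select: yields a when a < b, otherwise b.
    static int min(int a, int b) {
        return b ^ ((a ^ b) & ltMask(a, b));
    }
}
```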

Upvotes: 44

apangin

Reputation: 98332

While Louis Wasserman's answer is correct, I want to show you a more general (and much clearer) way to write branchless code. You can just use the ?: operator:

    int t = data[c];
    sum += (t >= 128 ? t : 0);

The JIT compiler sees from the execution profile that the condition is poorly predicted here. In such cases the compiler is smart enough to replace the conditional branch with a conditional move instruction:

    mov    0x10(%r14,%rbp,4),%r9d  ; load R9d from array
    cmp    $0x80,%r9d              ; compare with 128
    cmovl  %r8d,%r9d               ; if less, move R8d (which is 0) to R9d

You can verify for yourself that this version runs equally fast for both sorted and unsorted arrays.
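A crude way to try this is sketched below. The class name, array size, seed, and repetition count are arbitrary choices for illustration, and a plain nanoTime loop like this is only a rough check (a proper benchmark harness would be more reliable), so no particular timings are claimed.

```java
class TernarySum {
    // Ternary form; the JIT may or may not turn this into a conditional move.
    static long sumTernary(int[] data) {
        long sum = 0;
        for (int c = 0; c < data.length; c++) {
            int t = data[c];
            sum += (t >= 128 ? t : 0);
        }
        return sum;
    }

    public static void main(String[] args) {
        // Random values in 0..255, once unsorted and once sorted.
        int[] unsorted = new java.util.Random(42).ints(32768, 0, 256).toArray();
        int[] sorted = unsorted.clone();
        java.util.Arrays.sort(sorted);

        for (int[] data : new int[][] {unsorted, sorted}) {
            long start = System.nanoTime();
            long sum = 0;
            for (int i = 0; i < 10_000; i++) {
                sum += sumTernary(data);
            }
            long ms = (System.nanoTime() - start) / 1_000_000;
            System.out.println("sum=" + sum + " time=" + ms + " ms");
        }
    }
}
```

Both runs should print the same sum; only the times are of interest.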

Upvotes: 14
