degski

Reputation: 672

How to implement a lane-crossing logical bit-wise shift/rotate (left and right) in AVX2

How does one implement a lane-crossing logical bit-wise shift (left and right) in AVX2? I want to shift a whole __m256i as if it were a single 256-bit integer, with no element or lane boundaries.


An answer on another Q&A looked useful, but turned out to be about byte shifts, using _mm256_alignr_epi8 and _mm256_permute2x128_si256 with operands that depend on the compile-time-constant shift count. (See the revision history of this question for a full test program, written before realizing the technique is byte-granular and therefore only useful for bit-shift counts that are multiples of 8.)
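
For context, the byte-granularity version of that trick for a compile-time-constant count looks something like the sketch below (the template and helper name are illustrative, not from the linked answer); it can only move whole bytes, which is exactly why it does not answer this question:

#include <immintrin.h>

template <int N> // byte count, 0 < N < 16
__m256i mm256_byteshift_left(__m256i a){
    // t = [ high lane = low lane of a | low lane = zero ]
    __m256i t = _mm256_permute2x128_si256(a, a, 0x08);
    // Per 128-bit lane: (a_lane : t_lane) >> (16 - N) bytes, which
    // carries the top bytes of the low lane into the high lane.
    return _mm256_alignr_epi8(a, t, 16 - N);
}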

Upvotes: 7

Views: 1179

Answers (2)

degski

Reputation: 672

The following code implements lane-crossing logical bit-wise shift/rotate (left and right) in AVX2:

#include <immintrin.h>

// Prototypes...

__m256i _mm256_sli_si256 ( __m256i, int );
__m256i _mm256_sri_si256 ( __m256i, int );
__m256i _mm256_rli_si256 ( __m256i, int );
__m256i _mm256_rri_si256 ( __m256i, int );


// Implementations...
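// Each helper below covers one 64-bit range of shift counts: the per-element
// 64-bit shift produces the in-element bits, and _mm256_permute4x64_epi64
// moves the bits that cross element boundaries into place. In the shift
// helpers, a blend against zero clears the elements that must become zero;
// the rotate helpers skip the blend because no bits are discarded. The
// trailing number on each helper is its approximate operation count.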

__m256i left_shift_000_063 ( __m256i a, int n ) { // 6

    return _mm256_or_si256 ( _mm256_slli_epi64 ( a, n ), _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), _mm256_permute4x64_epi64 ( _mm256_srli_epi64 ( a, 64 - n ), _MM_SHUFFLE ( 2, 1, 0, 0 ) ), _MM_SHUFFLE ( 3, 3, 3, 0 ) ) );
}

__m256i left_shift_064_127 ( __m256i a, int n ) { // 7

    __m256i b = _mm256_slli_epi64 ( a, n );
    __m256i d = _mm256_permute4x64_epi64 ( b, _MM_SHUFFLE ( 2, 1, 0, 0 ) );

    __m256i c = _mm256_srli_epi64 ( a, 64 - n );
    __m256i e = _mm256_permute4x64_epi64 ( c, _MM_SHUFFLE ( 1, 0, 0, 0 ) );

    __m256i f = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), d, _MM_SHUFFLE ( 3, 3, 3, 0 ) );
    __m256i g = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), e, _MM_SHUFFLE ( 3, 3, 0, 0 ) );

    return _mm256_or_si256 ( f, g );
}

__m256i left_shift_128_191 ( __m256i a, int n ) { // 7

    __m256i b = _mm256_slli_epi64 ( a, n );
    __m256i d = _mm256_permute4x64_epi64 ( b, _MM_SHUFFLE ( 1, 0, 0, 0 ) );

    __m256i c = _mm256_srli_epi64 ( a, 64 - n );
    __m256i e = _mm256_permute4x64_epi64 ( c, _MM_SHUFFLE ( 1, 0, 0, 0 ) );

    __m256i f = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), d, _MM_SHUFFLE ( 3, 3, 0, 0 ) );
    __m256i g = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), e, _MM_SHUFFLE ( 3, 0, 0, 0 ) );

    return _mm256_or_si256 ( f, g );
}

__m256i left_shift_192_255 ( __m256i a, int n ) { // 5

    return _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), _mm256_slli_epi64 ( _mm256_permute4x64_epi64 ( a, _MM_SHUFFLE ( 0, 0, 0, 0 ) ), n ), _MM_SHUFFLE ( 3, 0, 0, 0 ) );
}

__m256i _mm256_sli_si256 ( __m256i a, int n ) {

    if ( n < 128 ) return n <  64 ? left_shift_000_063 ( a, n ) : left_shift_064_127 ( a, n % 64 );
    else           return n < 192 ? left_shift_128_191 ( a, n % 64 ) : left_shift_192_255 ( a, n % 64 );
}


__m256i right_shift_000_063 ( __m256i a, int n ) { // 6

    return _mm256_or_si256 ( _mm256_srli_epi64 ( a, n ), _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), _mm256_permute4x64_epi64 ( _mm256_slli_epi64 ( a, 64 - n ), _MM_SHUFFLE ( 0, 3, 2, 1 ) ), _MM_SHUFFLE ( 0, 3, 3, 3 ) ) );
}

__m256i right_shift_064_127 ( __m256i a, int n ) { // 7

    __m256i b = _mm256_srli_epi64 ( a, n );
    __m256i d = _mm256_permute4x64_epi64 ( b, _MM_SHUFFLE ( 3, 3, 2, 1 ) );

    __m256i c = _mm256_slli_epi64 ( a, 64 - n );
    __m256i e = _mm256_permute4x64_epi64 ( c, _MM_SHUFFLE ( 3, 3, 3, 2 ) );

    __m256i f = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), d, _MM_SHUFFLE ( 0, 3, 3, 3 ) );
    __m256i g = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), e, _MM_SHUFFLE ( 0, 0, 3, 3 ) );

    return _mm256_or_si256 ( f, g );
}

__m256i right_shift_128_191 ( __m256i a, int n ) { // 7

    __m256i b = _mm256_srli_epi64 ( a, n );
    __m256i d = _mm256_permute4x64_epi64 ( b, _MM_SHUFFLE ( 3, 2, 3, 2 ) );

    __m256i c = _mm256_slli_epi64 ( a, 64 - n );
    __m256i e = _mm256_permute4x64_epi64 ( c, _MM_SHUFFLE ( 3, 2, 1, 3 ) );

    __m256i f = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), d, _MM_SHUFFLE ( 0, 0, 3, 3 ) );
    __m256i g = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), e, _MM_SHUFFLE ( 0, 0, 0, 3 ) );

    return _mm256_or_si256 ( f, g );
}

__m256i right_shift_192_255 ( __m256i a, int n ) { // 5

    return _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), _mm256_srli_epi64 ( _mm256_permute4x64_epi64 ( a, _MM_SHUFFLE ( 0, 0, 0, 3 ) ), n ), _MM_SHUFFLE ( 0, 0, 0, 3 ) );
}

__m256i _mm256_sri_si256 ( __m256i a, int n ) {

    if ( n < 128 ) return n <  64 ? right_shift_000_063 ( a, n ) : right_shift_064_127 ( a, n % 64 );
    else           return n < 192 ? right_shift_128_191 ( a, n % 64 ) : right_shift_192_255 ( a, n % 64 );
}


__m256i left_rotate_000_063 ( __m256i a, int n ) { // 5

    return _mm256_or_si256 ( _mm256_slli_epi64 ( a, n ), _mm256_permute4x64_epi64 ( _mm256_srli_epi64 ( a, 64 - n ), _MM_SHUFFLE ( 2, 1, 0, 3 ) ) );
}

__m256i left_rotate_064_127 ( __m256i a, int n ) { // 6

    __m256i b = _mm256_slli_epi64 ( a, n );
    __m256i c = _mm256_srli_epi64 ( a, 64 - n );

    __m256i d = _mm256_permute4x64_epi64 ( b, _MM_SHUFFLE ( 2, 1, 0, 3 ) );
    __m256i e = _mm256_permute4x64_epi64 ( c, _MM_SHUFFLE ( 1, 0, 3, 2 ) );

    return _mm256_or_si256 ( d, e );
}

__m256i left_rotate_128_191 ( __m256i a, int n ) { // 6

    __m256i b = _mm256_slli_epi64 ( a, n );
    __m256i c = _mm256_srli_epi64 ( a, 64 - n );

    __m256i d = _mm256_permute4x64_epi64 ( b, _MM_SHUFFLE ( 1, 0, 3, 2 ) );
    __m256i e = _mm256_permute4x64_epi64 ( c, _MM_SHUFFLE ( 0, 3, 2, 1 ) );

    return _mm256_or_si256 ( d, e );
}

__m256i left_rotate_192_255 ( __m256i a, int n ) { // 5

    return _mm256_or_si256 ( _mm256_srli_epi64 ( a, 64 - n ), _mm256_permute4x64_epi64 ( _mm256_slli_epi64 ( a, n ), _MM_SHUFFLE ( 0, 3, 2, 1 ) ) );
}

__m256i _mm256_rli_si256 ( __m256i a, int n ) {

    if ( n < 128 ) return n <  64 ? left_rotate_000_063 ( a, n ) : left_rotate_064_127 ( a, n % 64 );
    else           return n < 192 ? left_rotate_128_191 ( a, n % 64 ) : left_rotate_192_255 ( a, n % 64 );
}


__m256i right_rotate_000_063 ( __m256i a, int n ) { // 5

    return _mm256_or_si256 ( _mm256_srli_epi64 ( a, n ), _mm256_permute4x64_epi64 ( _mm256_slli_epi64 ( a, 64 - n ), _MM_SHUFFLE ( 0, 3, 2, 1 ) ) );
}

__m256i right_rotate_064_127 ( __m256i a, int n ) { // 6

    __m256i b = _mm256_srli_epi64 ( a, n );
    __m256i c = _mm256_slli_epi64 ( a, 64 - n );

    __m256i d = _mm256_permute4x64_epi64 ( b, _MM_SHUFFLE ( 0, 3, 2, 1 ) );
    __m256i e = _mm256_permute4x64_epi64 ( c, _MM_SHUFFLE ( 1, 0, 3, 2 ) );

    return _mm256_or_si256 ( d, e );
}

__m256i right_rotate_128_191 ( __m256i a, int n ) { // 6

    __m256i b = _mm256_srli_epi64 ( a, n );
    __m256i c = _mm256_slli_epi64 ( a, 64 - n );

    __m256i d = _mm256_permute4x64_epi64 ( b, _MM_SHUFFLE ( 1, 0, 3, 2 ) );
    __m256i e = _mm256_permute4x64_epi64 ( c, _MM_SHUFFLE ( 2, 1, 0, 3 ) );

    return _mm256_or_si256 ( d, e );
}
__m256i right_rotate_192_255 ( __m256i a, int n ) { // 5

    return _mm256_or_si256 ( _mm256_slli_epi64 ( a, 64 - n ), _mm256_permute4x64_epi64 ( _mm256_srli_epi64 ( a, n ), _MM_SHUFFLE ( 2, 1, 0, 3 ) ) );
}

__m256i _mm256_rri_si256 ( __m256i a, int n ) {

    if ( n < 128 ) return n <  64 ? right_rotate_000_063 ( a, n      ) : right_rotate_064_127 ( a, n % 64 );
    else           return n < 192 ? right_rotate_128_191 ( a, n % 64 ) : right_rotate_192_255 ( a, n % 64 );
}
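
For illustration, here is a minimal driver (not part of the original answer; print_u64 is a hypothetical helper) that exercises the dispatch functions:

#include <cstdint>
#include <cstdio>

void print_u64 ( __m256i x ) {
    alignas ( 32 ) uint64_t s [ 4 ];
    _mm256_store_si256 ( ( __m256i * ) s, x );
    std::printf ( "%016llx %016llx %016llx %016llx\n",
        ( unsigned long long ) s [ 3 ], ( unsigned long long ) s [ 2 ],
        ( unsigned long long ) s [ 1 ], ( unsigned long long ) s [ 0 ] );
}

int main ( ) {
    __m256i a = _mm256_set_epi64x ( 0, 0, 0, 1 ); // bit 0 set
    print_u64 ( _mm256_sli_si256 ( a, 200 ) );    // expect bit 200 set: 0x100 in the top element
    print_u64 ( _mm256_rri_si256 ( a, 1 ) );      // expect bit 0 to rotate around to bit 255
}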

I have tried to arrange the _mm256_permute4x64_epi64 ops (when two are needed anyway) so that they can execute in parallel, which should keep the overall latency to a minimum.

Most of the suggestions and/or clues given by commenters were helpful in putting together this code; thanks to all. Improvements and/or any other comments are welcome.

I think that Mysticial's answer is interesting, but too complicated to be used effectively for generalized shifting/rotating, e.g. in a library.

Upvotes: 3

Mysticial

Reputation: 471499

Probably not the kind of answer that you're expecting. But here's a reasonably efficient solution that actually works for a run-time shift amount.

The costs are:

  • Preprocessing: ~12-14 instructions
  • Rotation: 5 instructions
  • Shift: 6 instructions

In order to shift or rotate anything, you must first preprocess the shift amount. Once you have that, you can efficiently perform shifts/rotations.

Because the preprocessing step is so expensive, this solution utilizes an object to hold the preprocessed shift amount so that it can be reused many times when shifting by the same amount.

For efficiency, the object should be on the stack in the same scope as the code that does the shifting. This allows the compiler to promote all the fields of the object into registers. Furthermore, it's recommended to force-inline all the methods of the class.

#include <stdint.h>
#include <immintrin.h>

class LeftShifter_AVX2{
public:
    LeftShifter_AVX2(uint32_t bits){
        //  Precompute all the necessary values.
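        //  Result dword i of the left-shift is built from source dwords
        //  (i - bits/32) and (i - bits/32 - 1).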
        permL = _mm256_sub_epi32(
            _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7),
            _mm256_set1_epi32(bits / 32)
        );
        permR = _mm256_sub_epi32(permL, _mm256_set1_epi32(1));

        bits %= 32;
        shiftL = _mm_cvtsi32_si128(bits);
        shiftR = _mm_cvtsi32_si128(32 - bits);
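        //  permutevar8x32 uses only the low 3 bits of each index, so the
        //  negative indices above wrap around mod 8; "mask" marks every bit
        //  that a wrapped element contributes, so shift() can clear it.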
        __m256i maskL = _mm256_cmpgt_epi32(_mm256_setzero_si256(), permL);
        __m256i maskR = _mm256_cmpgt_epi32(_mm256_setzero_si256(), permR);
        mask = _mm256_or_si256(maskL, _mm256_srl_epi32(maskR, shiftR));
    }

    __m256i rotate(__m256i x) const{
        __m256i L = _mm256_permutevar8x32_epi32(x, permL);
        __m256i R = _mm256_permutevar8x32_epi32(x, permR);
        L = _mm256_sll_epi32(L, shiftL);
        R = _mm256_srl_epi32(R, shiftR);
        return _mm256_or_si256(L, R);
    }
    __m256i shift(__m256i x) const{
        //  A shift is the rotation with the wrapped-around bits cleared.
        return _mm256_andnot_si256(mask, rotate(x));
    }

private:
    __m256i permL;
    __m256i permR;
    __m128i shiftL;
    __m128i shiftR;
    __m256i mask;
};

Test Program:

#include <iostream>
using namespace std;

void print_u8(__m256i x){
    union{
        __m256i v;
        uint8_t s[32];
    };
    v = x;
    for (int c = 0; c < 32; c++){
        cout << (int)s[c] << " ";
    }
    cout << endl;
}

int main(){
    union{
        __m256i x;
        char buffer[32];
    };
    for (int c = 0; c < 32; c++){
        buffer[c] = (char)c;
    }
    print_u8(x);
    print_u8(LeftShifter_AVX2(0).shift(x));
    print_u8(LeftShifter_AVX2(8).shift(x));
    print_u8(LeftShifter_AVX2(32).shift(x));
    print_u8(LeftShifter_AVX2(40).shift(x));
}

Output:

0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 
0 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 
0 0 0 0 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 
0 0 0 0 0 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26

Right-shift is very similar. I'll leave that as an exercise for the reader.
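
For completeness, here is a sketch of what that exercise might look like (my mirroring of the class above, not code from the original answer): result dword i now gathers from dwords (i + bits/32) and (i + bits/32 + 1), indices above 7 wrap, and the mask clears high bits instead of low bits.

class RightShifter_AVX2{
public:
    RightShifter_AVX2(uint32_t bits){
        //  Result dword i gathers from source dwords (i + bits/32)
        //  and (i + bits/32 + 1).
        permR = _mm256_add_epi32(
            _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7),
            _mm256_set1_epi32(bits / 32)
        );
        permL = _mm256_add_epi32(permR, _mm256_set1_epi32(1));

        bits %= 32;
        shiftR = _mm_cvtsi32_si128(bits);
        shiftL = _mm_cvtsi32_si128(32 - bits);
        //  Indices above 7 wrap in permutevar8x32; mask the bits they feed.
        __m256i maskR = _mm256_cmpgt_epi32(permR, _mm256_set1_epi32(7));
        __m256i maskL = _mm256_cmpgt_epi32(permL, _mm256_set1_epi32(7));
        mask = _mm256_or_si256(maskR, _mm256_sll_epi32(maskL, shiftL));
    }

    __m256i rotate(__m256i x) const{
        __m256i R = _mm256_permutevar8x32_epi32(x, permR);
        __m256i L = _mm256_permutevar8x32_epi32(x, permL);
        R = _mm256_srl_epi32(R, shiftR);
        L = _mm256_sll_epi32(L, shiftL);
        return _mm256_or_si256(L, R);
    }
    __m256i shift(__m256i x) const{
        return _mm256_andnot_si256(mask, rotate(x));
    }

private:
    __m256i permL;
    __m256i permR;
    __m128i shiftL;
    __m128i shiftR;
    __m256i mask;
};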

Upvotes: 9
