Maj mac
Maj mac

Reputation: 51

AVX(2)/SIMD way to get/set (to 1) a single bit in a 256-bit register

My current, admittedly hacky, approach is this:

__m256i bitset(__m256i source, uint8_t index) {
    uint8_t pos_in_64 = index % 64;
    uint8_t location = index / 64;
    uint64_t bitmask = 1ULL << pos_in_64;

    __m256i mask = _mm256_setzero_si256();
    switch (location) {
    case 0: mask = _mm256_set_epi64x(0, 0, 0, bitmask); break;
    case 1: mask = _mm256_set_epi64x(0, 0, bitmask, 0); break;
    case 2: mask = _mm256_set_epi64x(0, bitmask, 0, 0); break;
    case 3: mask = _mm256_set_epi64x(bitmask, 0, 0, 0); break;
    }

    return _mm256_or_si256(source, mask);
}

bool bitget(__m256i source, uint8_t index) {
    uint8_t pos_in_64 = index % 64;
    uint8_t location = index / 64;
    uint64_t bitmask = 1ULL << pos_in_64;

    uint64_t extracted = 0;
    switch (location) {
    case 0: extracted = _mm256_extract_epi64(source, 0); break;
    case 1: extracted = _mm256_extract_epi64(source, 1); break;
    case 2: extracted = _mm256_extract_epi64(source, 2); break;
    case 3: extracted = _mm256_extract_epi64(source, 3); break;
    }
    return (extracted & bitmask) != 0;
}

But I'm sure there's a saner way to implement this, without the switch statements and all the scalar code.

Upvotes: 3

Views: 191

Answers (1)

Soonts
Soonts

Reputation: 21956

I would do it like this:

#include <stdint.h>
#include <immintrin.h>
#include <stdlib.h>

namespace
{
    // Rotate the bits of the unsigned 64-bit integer "value" left by "shift" positions:
    // bits pushed out of the most-significant end re-enter at the least-significant end.
    // A shift of 0 returns "value" unchanged.
    inline uint64_t rotateLeft( uint64_t value, int shift )
    {
#if defined(_MSC_VER) || defined(__INTEL_COMPILER)
        // _rotl64 for MSVC and Intel compilers
        return _rotl64( value, shift );
#elif defined(__clang__)
        // __builtin_rotateleft64 for clang compiler.
        // This test has to come before __GNUC__ because clang defines __GNUC__ as well;
        // with the opposite order this branch can never be selected.
        return __builtin_rotateleft64( value, shift );
#elif defined(__GNUC__)
        // __rolq for gcc compiler
        return __rolq( value, shift );
#else
        // Fallback for weird compilers.
        // Both shift counts are reduced to the 0..63 range: shifting a 64-bit integer
        // by 64 is undefined behaviour, and a plain ( 64 - shift ) hits that when shift is 0.
        const unsigned count = (unsigned)shift & 63u;
        return ( value << count ) | ( value >> ( ( 64u - count ) & 63u ) );
#endif
    }

    // Make int32 vector [ v, 0, 0, 0, U, U, U, U ] where U = undefined
    // The low 128 bits come from _mm_cvtsi32_si128, which zeroes lanes 1-3;
    // the cast to 256 bits leaves the upper half unspecified.
    inline __m256i setLow( uint32_t v )
    {
        return _mm256_castsi128_si256( _mm_cvtsi32_si128( (int)v ) );
    }
}

// Build a 256-bit vector in which only bit number "index" (0..255) is set.
inline __m256i makeSingleBitMask( uint8_t index )
{
    // Lane 0 carries the bit at its position within a 32-bit lane; lanes 1-3 are zero
    const uint32_t laneBit = 1u << ( index & 31u );
    const __m256i seed = setLow( laneBit );

    // Target lane number times 8: the rotation below works on whole bytes
    const int rotation = (int)( index >> 5 ) * 8;

    // Eight packed byte selectors: a single 0 byte among 1 bytes.
    // Rotating moves the 0 byte to the position of the target lane.
    const uint64_t packedSelectors = rotateLeft( 0x0101010101010100ull, rotation );

    // Widen the bytes into the eight int32 indices expected by vpermd
    const __m256i selectors = _mm256_cvtepu8_epi32( _mm_cvtsi64_si128( (long long)packedSelectors ) );

    // The target lane picks seed lane 0 (the bit); every other lane picks seed lane 1 (zero)
    return _mm256_permutevar8x32_epi32( seed, selectors );
}

// Return a copy of "source" with bit number "index" (0..255) set to 1.
__m256i bitSetAvx( __m256i source, uint8_t index )
{
    // OR the input with a vector that has only the requested bit set
    return _mm256_or_si256( source, makeSingleBitMask( index ) );
}

// Report whether bit number "index" (0..255) of "source" is set.
bool bitTestAvx( __m256i source, uint8_t index )
{
    // Bring the 32-bit lane holding the bit down to lane 0, then move it to a scalar.
    // Only lane 0 of the permuted vector is read afterwards.
    const __m256i selector = setLow( (uint32_t)( index >> 5 ) );
    const __m256i moved = _mm256_permutevar8x32_epi32( source, selector );
    const uint32_t word = (uint32_t)_mm256_cvtsi256_si32( moved );

    // Shift the requested bit down to position 0 and test it
    return ( ( word >> ( index & 31u ) ) & 1u ) != 0;
}

Take a look at the assembly output produced by the GCC 14 compiler when asked to optimize for AMD Zen 3. As you can see, these two pages of C++ compile into just 12 instructions for bit set and 10 for bit test, without any branches or memory accesses.

Upvotes: 2

Related Questions