Andrew Trotman

Reputation: 133

Gather / Scatter 16-bit integers using AVX-512

I've been trying to work out how we're supposed to scatter 16-bit integers using the scatter instructions in AVX-512. What I have is 8 x 16-bit integers, stored one in each of the 32-bit integers of an __m256i. I'd use a 256-bit equivalent of _mm512_i32extscatter_epi32 with the down-conversion _MM_DOWNCONV_EPI32_UINT16, but there is no such instruction, and down-conversion on scatter isn't available in AVX-512 (the ext-scatter intrinsics come from the Knights Corner instruction set).

My understanding is this: we have to do 32-bit reads and writes, and we have to be careful that two 16-bit writes to the two halves of the same 32-bit word don't trash each other (if the same index is in the index list twice, I don't need to worry about which one happens first). So we have to use a conflict-detecting gather/scatter loop. In the loop we have to detect conflicts on the 32-bit integer addresses, that is, on the 16-bit indexes shifted right by 1 and used as indexes into the equivalent 32-bit array (the equivalent of casting the 16-bit array to a 32-bit array and dividing the index by 2). Then we need to take the 32-bit integer we read and alter either its high 16 bits or its low 16 bits, depending on whether the original index into the 16-bit array was odd or even.
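As a concrete picture of that read-modify-write, here is what a single lane does, written as scalar C. This is a sketch, not the AVX-512 code: `store16_via_word` is a hypothetical helper, and it assumes a little-endian target (as on x86), where 16-bit element 2k is the low half of 32-bit word k and element 2k+1 is the high half.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: write a 16-bit value at 16-bit index idx using only
   an aligned 32-bit read and an aligned 32-bit write, as one lane of the
   gather/blend/scatter does. Little-endian layout assumed. */
static void store16_via_word(uint32_t *words, size_t nwords,
                             uint32_t idx, uint16_t value)
{
    uint32_t word_index = idx >> 1;          /* 16-bit index / 2 */
    uint32_t shift = (idx & 1u) * 16u;       /* odd index -> high half */
    assert(word_index < nwords);

    uint32_t old = words[word_index];                    /* 32-bit read */
    uint32_t keep = ~((uint32_t)0xFFFFu << shift);       /* the other half */
    words[word_index] = (old & keep) | ((uint32_t)value << shift);  /* 32-bit write */
}
```

Two such writes to indexes 2k and 2k+1 touch the same word, which is exactly why the vector version needs the conflict loop.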

So here's what I get:

  1. Work out if the indexes are odd or even and set a 2-bit mask of 01 or 10 accordingly, forming a 16-bit mask for 8 integers.

  2. Turn the 16-bit integers into 32-bit integers by copying the low 16 bits into the high 16 bits.

  3. Turn the index into the array of 16-bit integers into an index into an array of 32-bit integers by shifting right by one.

  4. Use a conflict-detection loop with a mask of the un-written lanes.

  5. Masked-gather 32-bit integers

  6. Use _mm256_mask_blend_epi16 to choose whether to alter the high or low 16-bits of the 32-bit integers just read (using the mask from (1)).

  7. Masked-scatter back to memory

  8. Repeat until all eight lanes have been written (each pass writes the lanes that have no conflicts among the un-written 32-bit integer addresses).
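The eight steps above can be modelled in plain C to check the control flow: that every pass makes progress, and that a repeated index ends up holding the value from the highest lane. This is a sketch under stated assumptions: `scatter16_conflict_model` is a hypothetical name, the inner loops stand in for `_mm256_maskz_conflict_epi32`, the masked gather and the masked scatter, and the layout is assumed little-endian.

```c
#include <assert.h>
#include <stdint.h>

#define LANES 8

/* Hypothetical scalar model of steps 3-8 (little-endian layout assumed).
   Returns the number of passes the conflict loop needed. */
static int scatter16_conflict_model(uint32_t *words, const uint32_t idx[LANES],
                                    const uint16_t val[LANES])
{
    uint8_t unwritten = 0xFF;
    int passes = 0;
    while (unwritten != 0) {
        /* Step 4: a lane may go this pass if no earlier *unwritten* lane
           targets the same 32-bit word (VPCONFLICTD, ANDed with unwritten). */
        uint8_t mask = 0;
        for (int j = 0; j < LANES; ++j) {
            if (!(unwritten & (1u << j)))
                continue;
            int conflict = 0;
            for (int i = 0; i < j; ++i)
                if ((unwritten & (1u << i)) && (idx[i] >> 1) == (idx[j] >> 1))
                    conflict = 1;
            if (!conflict)
                mask |= (uint8_t)(1u << j);
        }
        /* The lowest unwritten lane never conflicts, so every pass makes progress. */
        assert(mask != 0);

        /* Step 5: masked gather of the selected 32-bit words. */
        uint32_t was[LANES] = {0};
        for (int j = 0; j < LANES; ++j)
            if (mask & (1u << j))
                was[j] = words[idx[j] >> 1];

        /* Steps 6-7: replace the low or high half, then masked scatter.
           Lanes selected in one pass hit distinct words. */
        for (int j = 0; j < LANES; ++j)
            if (mask & (1u << j)) {
                uint32_t shift = (idx[j] & 1u) * 16u;   /* odd index -> high half */
                uint32_t keep = ~((uint32_t)0xFFFFu << shift);
                words[idx[j] >> 1] = (was[j] & keep) | ((uint32_t)val[j] << shift);
            }

        unwritten ^= mask;   /* Step 8: retire the written lanes. */
        ++passes;
    }
    return passes;
}
```

With the indexes 0..7, lanes 0, 2, 4 and 6 go in the first pass and lanes 1, 3, 5 and 7 in the second, so the function returns 2; eight lanes that all land in the same 32-bit word need eight passes.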

Is there a faster (or simpler) way to do this? And yes, I know individual writes are faster, but this is about working out how to do it using AVX-512.

Here's the code:

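#include <immintrin.h>
#include <stdint.h>

// Requires AVX-512F, AVX-512VL, AVX-512BW and AVX-512CD.
void scatter(uint16_t *array, __m256i vindex, __m256i a)
    {
    // (1) Per 16-bit half: bit 2k+1 set if index k is odd (write the high half),
    //     bit 2k set if it is even (write the low half).
    __mmask16 odd = _mm256_test_epi16_mask(vindex, _mm256_set1_epi32(1));
    __mmask16 even = ~odd & 0x5555;
    __mmask16 odd_even = odd << 1 | even;

    // (2) Copy the low 16 bits of each 32-bit element into its high 16 bits.
    __m256i data = _mm256_mask_blend_epi16(0x5555, _mm256_bslli_epi128(a, 2), a);

    // (3) 16-bit index -> index of the 32-bit word that contains it.
    __m256i word_locations = _mm256_srli_epi32(vindex, 1);
    __mmask8 unwritten = 0xFF;
    do
        {
        // (4) Keep the lanes with no earlier un-written lane targeting the same word.
        __m256i conflict = _mm256_maskz_conflict_epi32 (unwritten, word_locations);
        conflict = _mm256_and_si256(_mm256_set1_epi32(unwritten), conflict);
        __mmask8 mask = unwritten & _mm256_testn_epi32_mask(conflict, _mm256_set1_epi32(-1));

        // (5)-(7) Masked gather, replace the selected half, masked scatter.
        __m256i was = _mm256_mmask_i32gather_epi32(_mm256_setzero_si256(), mask, word_locations, array, 4);
        __m256i send = _mm256_mask_blend_epi16(odd_even, was, data);
        _mm256_mask_i32scatter_epi32(array, mask, word_locations, send, 4);

        // (8) Retire the lanes just written.
        unwritten ^= mask;
        }
    while (unwritten != 0);
    }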

Upvotes: 5

Views: 1437

Answers (2)

genesisviva

Reputation: 1

I think the most straightforward implementation uses _mm_extract_epi16, which basically amounts to individual writes:

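#include <immintrin.h>

void mm256_i16scatter_epi16(short* base_addr, __m256i indices, __m256i values, const int scale) {
    // Treats indices and values as 16 x 16-bit lanes. scale is in bytes, as for the
    // hardware scatters, and is assumed to be 2, 4 or 8 (scale / 2 is 0 when scale is 1).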
    __m128i indices_low = _mm256_extracti128_si256(indices, 0);
    __m128i indices_high = _mm256_extracti128_si256(indices, 1);
    __m128i values_low = _mm256_extracti128_si256(values, 0);
    __m128i values_high = _mm256_extracti128_si256(values, 1);

    base_addr[_mm_extract_epi16(indices_low, 0) * (scale / 2)] = _mm_extract_epi16(values_low, 0);
    base_addr[_mm_extract_epi16(indices_low, 1) * (scale / 2)] = _mm_extract_epi16(values_low, 1);
    base_addr[_mm_extract_epi16(indices_low, 2) * (scale / 2)] = _mm_extract_epi16(values_low, 2);
    base_addr[_mm_extract_epi16(indices_low, 3) * (scale / 2)] = _mm_extract_epi16(values_low, 3);
    base_addr[_mm_extract_epi16(indices_low, 4) * (scale / 2)] = _mm_extract_epi16(values_low, 4);
    base_addr[_mm_extract_epi16(indices_low, 5) * (scale / 2)] = _mm_extract_epi16(values_low, 5);
    base_addr[_mm_extract_epi16(indices_low, 6) * (scale / 2)] = _mm_extract_epi16(values_low, 6);
    base_addr[_mm_extract_epi16(indices_low, 7) * (scale / 2)] = _mm_extract_epi16(values_low, 7);

    base_addr[_mm_extract_epi16(indices_high, 0) * (scale / 2)] = _mm_extract_epi16(values_high, 0);
    base_addr[_mm_extract_epi16(indices_high, 1) * (scale / 2)] = _mm_extract_epi16(values_high, 1);
    base_addr[_mm_extract_epi16(indices_high, 2) * (scale / 2)] = _mm_extract_epi16(values_high, 2);
    base_addr[_mm_extract_epi16(indices_high, 3) * (scale / 2)] = _mm_extract_epi16(values_high, 3);
    base_addr[_mm_extract_epi16(indices_high, 4) * (scale / 2)] = _mm_extract_epi16(values_high, 4);
    base_addr[_mm_extract_epi16(indices_high, 5) * (scale / 2)] = _mm_extract_epi16(values_high, 5);
    base_addr[_mm_extract_epi16(indices_high, 6) * (scale / 2)] = _mm_extract_epi16(values_high, 6);
    base_addr[_mm_extract_epi16(indices_high, 7) * (scale / 2)] = _mm_extract_epi16(values_high, 7);

}
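`_mm_extract_epi16` takes its lane number as a compile-time constant, which is why the answer spells out all 16 stores. If the vectors are first stored to arrays (for example with `_mm256_storeu_si256`), the same individual writes become a loop. A sketch: `scatter16_loop` is a hypothetical name, it treats the data as sixteen 16-bit lanes like the answer does, and it interprets `scale` as a byte multiplier like the hardware scatters, using `memcpy` so that `scale == 1` still works.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical loop form of the extract-based scatter, with the 16 indices
   and values already spilled to arrays. scale is a byte multiplier
   (1, 2, 4 or 8); memcpy keeps the unaligned scale == 1 case well defined. */
static void scatter16_loop(uint16_t *base, const uint16_t idx[16],
                           const uint16_t val[16], int scale)
{
    assert(scale == 1 || scale == 2 || scale == 4 || scale == 8);
    for (int i = 0; i < 16; ++i)
        memcpy((unsigned char *)base + (size_t)idx[i] * (size_t)scale,
               &val[i], sizeof val[i]);
}
```

Lanes are written in order, so for a repeated index the highest lane wins, the same as with the unrolled version.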

Upvotes: 0

chtz

Reputation: 18827

If it is safe to read from/write to the two bytes after the last index, this should work as well:

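#include <immintrin.h>
#include <stdint.h>

void scatter2(uint16_t *array, __m256i vindex, __m256i a) {
  // Lanes whose 16-bit index is odd.
  __mmask8 odd = _mm256_test_epi32_mask(vindex, _mm256_set1_epi32(1));

  // scale 2: each lane reads/writes the 32 bits at array[idx] and array[idx + 1].
  int32_t* arr32 = (int32_t*)array;
  __m256i was_odd = _mm256_i32gather_epi32(arr32, vindex, 2);

  // Even indices never overlap each other: new low half, gathered high half.
  __m256i data_even = _mm256_mask_blend_epi16(0x5555, was_odd, a);
  _mm256_mask_i32scatter_epi32(array, (__mmask8)~odd, vindex, data_even, 2);
  // Re-gather so that the odd lanes see what the even lanes just wrote.
  __m256i was_even = _mm256_i32gather_epi32(arr32, vindex, 2);

  __m256i data_odd = _mm256_mask_blend_epi16(0x5555, was_even, a);
  _mm256_mask_i32scatter_epi32(array, odd, vindex, data_odd, 2);
}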
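Why the two passes are safe can be checked with a scalar model of `scatter2`. This sketch (hypothetical name `scatter16_two_pass_model`, little-endian layout assumed) replaces the gathers and scatters with unaligned 32-bit `memcpy`s at `array + idx`, which is what scale 2 does, and takes the array length so it can assert that the spare element past the largest index exists.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define LANES 8

/* Hypothetical scalar model of scatter2 (little-endian layout assumed).
   A 32-bit access at byte offset 2 * idx covers 16-bit elements idx and
   idx + 1, so array needs one spare element past the largest index. */
static void scatter16_two_pass_model(uint16_t *array, size_t n,
                                     const uint32_t idx[LANES],
                                     const uint16_t val[LANES])
{
    for (uint32_t parity = 0; parity < 2; ++parity) {   /* even pass, then odd pass */
        /* Unmasked gather: low half = array[idx], high half = array[idx + 1]. */
        uint32_t was[LANES];
        for (int j = 0; j < LANES; ++j) {
            assert((size_t)idx[j] + 1 < n);
            memcpy(&was[j], array + idx[j], sizeof was[j]);
        }
        /* Masked scatter of this parity only. Two distinct even (or two
           distinct odd) indices are at least 2 apart, so their 32-bit
           writes never overlap; equal indices overwrite in lane order. */
        for (int j = 0; j < LANES; ++j) {
            if ((idx[j] & 1u) != parity)
                continue;
            uint32_t data = (was[j] & 0xFFFF0000u) | val[j];  /* blend 0x5555 */
            memcpy(array + idx[j], &data, sizeof data);
        }
    }
}
```

With indexes {7, 2, 3, 8, 0, 1, 5, 4}, the even pass writes elements 0, 2, 4 and 8 (rewriting 1, 3, 5 and 9 with their old values), and the odd pass then re-gathers and fills in 1, 3, 5 and 7.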

If you could guarantee that the indexes in vindex are increasing (or at least that, for any partially conflicting pair {i, i+1} in vindex, i+1 comes after i), you can probably get away with a single gather+blend+scatter. Also, it might be beneficial to use masked gathers (i.e., each time only gather the elements which you overwrite next); I'm not sure if this has an impact on throughput. Finally, the _mm256_mask_blend_epi16 could actually be replaced by a simple AVX2 _mm256_blend_epi16, e.g. _mm256_blend_epi16(was_odd, a, 0x55) (its 8-bit immediate is applied to both 128-bit lanes).
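The ordering condition for the single-pass variant can be illustrated the same way. The Intel SDM states that scatter writes to overlapping (including partially overlapping) locations are ordered from the lowest to the highest lane, which the second loop below mimics. A sketch with a hypothetical name, little-endian layout assumed: each lane's write also rewrites element idx + 1 with its stale gathered value, so a lane writing idx + 1 has to come after it.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define LANES 8

/* Hypothetical scalar model of a single gather + blend + scatter:
   gather every lane first, then write the lanes in order. */
static void scatter16_single_pass_model(uint16_t *array, size_t n,
                                        const uint32_t idx[LANES],
                                        const uint16_t val[LANES])
{
    uint32_t was[LANES];
    for (int j = 0; j < LANES; ++j) {            /* one unmasked gather */
        assert((size_t)idx[j] + 1 < n);
        memcpy(&was[j], array + idx[j], sizeof was[j]);
    }
    for (int j = 0; j < LANES; ++j) {            /* writes land lowest lane first */
        uint32_t data = (was[j] & 0xFFFF0000u) | val[j];  /* stale high half */
        memcpy(array + idx[j], &data, sizeof data);
    }
}
```

With vindex = {0, 1, ..., 7} every element ends up correct, because each later lane repairs the stale half left by the previous one. With {1, 0, ...}, lane 1's write to element 0 also puts the stale value back into element 1, so lane 0's value for index 1 is lost.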

Upvotes: 1
