Reputation: 133
I've been trying to work out how we're supposed to scatter 16-bit integers using the scatter instructions in AVX-512. What I have is 8 x 16-bit integers stored one in each of the eight 32-bit elements of an __m256i. I'd use a 256-bit equivalent of _mm512_i32extscatter_epi32 with the _MM_DOWNCONV_EPI32_UINT16 down-conversion, but there is no such instruction, and the down-converting scatters don't exist in AVX-512 anyway (they are Xeon Phi instructions).
My understanding is this: we have to do 32-bit reads and writes, and we have to be careful that two writes to adjacent 16-bit elements don't trash each other (if the same index appears in the index list twice, I don't need to worry about which one happens first). So we have to use a conflict-detection gather/scatter loop. In the loop we have to detect conflicts on the 32-bit word addresses, i.e. on the 16-bit indexes shifted right by 1 and used as indexes into the equivalent 32-bit array (the equivalent of casting the 16-bit array to a 32-bit array and dividing the index by 2). Then we take each 32-bit word we read and alter either its high 16 bits or its low 16 bits, depending on whether the original index into the 16-bit array was odd or even.
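To pin down the semantics, here's a scalar model of that read-modify-write (my own sketch, strict-aliasing caveats aside):

#include <stdint.h>

/* Scalar model of the 16-bit scatter done via 32-bit words (illustration
   only; casting uint16_t* to uint32_t* technically violates strict aliasing). */
void scatter_scalar(uint16_t *array, const int32_t idx[8], const uint16_t val[8])
{
    uint32_t *words = (uint32_t *)array;              /* view the array as 32-bit words */
    for (int i = 0; i < 8; ++i) {
        uint32_t w = words[idx[i] >> 1];              /* 32-bit read */
        if (idx[i] & 1)
            w = (w & 0x0000FFFFu) | ((uint32_t)val[i] << 16);  /* replace high half */
        else
            w = (w & 0xFFFF0000u) | val[i];                    /* replace low half */
        words[idx[i] >> 1] = w;                       /* 32-bit write */
    }
}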
So here's what I get:
1. Work out whether each index is odd or even and set a 2-bit mask of 01 or 10 accordingly, forming a 16-bit mask for the 8 integers.
2. Turn the 16-bit integers into 32-bit integers by copying the low 16 bits into the high 16 bits.
3. Turn each index into the array of 16-bit integers into an index into an array of 32-bit integers by shifting right by one.
4. Use a conflict loop with a mask:
5. Masked-gather 32-bit integers.
6. Use _mm256_mask_blend_epi16 to choose whether to alter the high or the low 16 bits of the 32-bit integers just read (using the mask from step 1).
7. Masked-scatter back to memory.
8. Repeat until there are no conflicts among the unwritten 32-bit word addresses.
Please, is there a faster (or simpler) way to do this? And yes, I know, individual writes are faster - but this is about working out how to do it using AVX-512.
Here's the code:
#include <immintrin.h>
#include <stdint.h>

void scatter(uint16_t *array, __m256i vindex, __m256i a)
{
    // Bit 2i is set if index i is odd (only the low word of each 32-bit
    // index element can have bit 0 set).
    __mmask16 odd = _mm256_test_epi16_mask(vindex, _mm256_set1_epi32(1));
    __mmask16 even = ~odd & 0x5555;
    // Per 32-bit element: 10 selects the high word, 01 the low word.
    __mmask16 odd_even = odd << 1 | even;
    // Duplicate the low 16 bits of each element into the high 16 bits.
    __m256i data = _mm256_mask_blend_epi16(0x5555, _mm256_bslli_epi128(a, 2), a);
    // Index into the array viewed as 32-bit words.
    __m256i word_locations = _mm256_srli_epi32(vindex, 1);
    __mmask8 unwritten = 0xFF;
    do
    {
        // For each unwritten element, find earlier elements sharing its word.
        __m256i conflict = _mm256_maskz_conflict_epi32(unwritten, word_locations);
        // Ignore conflicts with already-written elements.
        conflict = _mm256_and_si256(_mm256_set1_epi32(unwritten), conflict);
        // Elements with no remaining conflicts are safe to write this round.
        __mmask8 mask = unwritten & _mm256_testn_epi32_mask(conflict, _mm256_set1_epi32(0xFFFFFFFF));
        __m256i was = _mm256_mmask_i32gather_epi32(_mm256_setzero_si256(), mask, word_locations, array, 4);
        // Merge the new 16-bit value into the half selected by odd_even.
        __m256i send = _mm256_mask_blend_epi16(odd_even, was, data);
        _mm256_mask_i32scatter_epi32(array, mask, word_locations, send, 4);
        unwritten ^= mask;
    }
    while (unwritten != 0);
}
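For what it's worth, this is the kind of harness I'm exercising it with (hypothetical values, including a duplicated index and a pair of adjacent indexes; assumes AVX-512F/CD/BW/VL support):

#include <stdio.h>

int main(void)
{
    uint16_t array[16] = {0};
    /* Indexes 2 and 3 share a 32-bit word, and index 5 appears twice. */
    __m256i vindex = _mm256_setr_epi32(0, 2, 3, 5, 5, 8, 9, 14);
    __m256i a      = _mm256_setr_epi32(100, 101, 102, 103, 104, 105, 106, 107);
    scatter(array, vindex, a);
    for (int i = 0; i < 16; ++i)
        printf("array[%d] = %d\n", i, array[i]);
    return 0;
}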
Upvotes: 5
Views: 1437
Reputation: 1
I think the most straightforward implementation is to use _mm_extract_epi16, which basically makes it almost like individual writes:
void mm256_i16scatter_epi16(short* base_addr, __m256i indices, __m256i values, const int scale) {
    // Assuming scale is a power of 2; scale / 2 converts the byte scale
    // into a short-element scale, since base_addr is a short*.
    __m128i indices_low  = _mm256_extracti128_si256(indices, 0);
    __m128i indices_high = _mm256_extracti128_si256(indices, 1);
    __m128i values_low   = _mm256_extracti128_si256(values, 0);
    __m128i values_high  = _mm256_extracti128_si256(values, 1);
    // _mm_extract_epi16 requires a compile-time-constant lane, hence the unrolling.
    base_addr[_mm_extract_epi16(indices_low, 0) * (scale / 2)] = _mm_extract_epi16(values_low, 0);
    base_addr[_mm_extract_epi16(indices_low, 1) * (scale / 2)] = _mm_extract_epi16(values_low, 1);
    base_addr[_mm_extract_epi16(indices_low, 2) * (scale / 2)] = _mm_extract_epi16(values_low, 2);
    base_addr[_mm_extract_epi16(indices_low, 3) * (scale / 2)] = _mm_extract_epi16(values_low, 3);
    base_addr[_mm_extract_epi16(indices_low, 4) * (scale / 2)] = _mm_extract_epi16(values_low, 4);
    base_addr[_mm_extract_epi16(indices_low, 5) * (scale / 2)] = _mm_extract_epi16(values_low, 5);
    base_addr[_mm_extract_epi16(indices_low, 6) * (scale / 2)] = _mm_extract_epi16(values_low, 6);
    base_addr[_mm_extract_epi16(indices_low, 7) * (scale / 2)] = _mm_extract_epi16(values_low, 7);
    base_addr[_mm_extract_epi16(indices_high, 0) * (scale / 2)] = _mm_extract_epi16(values_high, 0);
    base_addr[_mm_extract_epi16(indices_high, 1) * (scale / 2)] = _mm_extract_epi16(values_high, 1);
    base_addr[_mm_extract_epi16(indices_high, 2) * (scale / 2)] = _mm_extract_epi16(values_high, 2);
    base_addr[_mm_extract_epi16(indices_high, 3) * (scale / 2)] = _mm_extract_epi16(values_high, 3);
    base_addr[_mm_extract_epi16(indices_high, 4) * (scale / 2)] = _mm_extract_epi16(values_high, 4);
    base_addr[_mm_extract_epi16(indices_high, 5) * (scale / 2)] = _mm_extract_epi16(values_high, 5);
    base_addr[_mm_extract_epi16(indices_high, 6) * (scale / 2)] = _mm_extract_epi16(values_high, 6);
    base_addr[_mm_extract_epi16(indices_high, 7) * (scale / 2)] = _mm_extract_epi16(values_high, 7);
}
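The same idea with less source-level unrolling, as a sketch of my own (not the code above): spill both vectors to the stack and loop; compilers typically unroll this into comparable scalar stores.

#include <immintrin.h>
#include <stdint.h>

void mm256_i16scatter_epi16_loop(short* base_addr, __m256i indices, __m256i values, int scale) {
    uint16_t idx[16], val[16];
    _mm256_storeu_si256((__m256i*)idx, indices);  // 16 x 16-bit indexes
    _mm256_storeu_si256((__m256i*)val, values);   // 16 x 16-bit values
    for (int i = 0; i < 16; ++i)
        base_addr[idx[i] * (scale / 2)] = (short)val[i];
}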
Upvotes: 0
Reputation: 18827
If it is safe to read from/write to the two bytes after the highest-indexed element, this should work as well:
void scatter2(uint16_t *array, __m256i vindex, __m256i a) {
    // Bit i is set if index i is odd.
    __mmask8 odd = _mm256_test_epi32_mask(vindex, _mm256_set1_epi32(1));
    int32_t* arr32 = (int32_t*)array;
    // Each gather/scatter touches 32 bits starting at &array[index]: the new
    // value goes in the low half, the gathered high half is written back.
    // Write the even indexes first, re-gather to pick up those writes, then
    // write the odd indexes, so partially conflicting {i, i+1} pairs don't
    // clobber each other.
    __m256i was_odd = _mm256_i32gather_epi32(arr32, vindex, 2);
    __m256i data_even = _mm256_mask_blend_epi16(0x5555, was_odd, a);
    _mm256_mask_i32scatter_epi32(array, ~odd, vindex, data_even, 2);
    __m256i was_even = _mm256_i32gather_epi32(arr32, vindex, 2);
    __m256i data_odd = _mm256_mask_blend_epi16(0x5555, was_even, a);
    _mm256_mask_i32scatter_epi32(array, odd, vindex, data_odd, 2);
}
If you could guarantee that the indexes in vindex are increasing (or at least that for any partially conflicting {i, i+1} in vindex, i+1 comes after i), you can probably get away with a single gather+blend+scatter. Also, it might be beneficial to use masked gathers (i.e., each time only gather the elements which you overwrite next) -- I'm not sure if this has an impact on throughput. Finally, the _mm256_mask_blend_epi16 could actually be replaced by a simple _mm256_blend_epi16.
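A minimal sketch of that single gather+blend+scatter, assuming the increasing-index guarantee holds; it relies on AVX-512 scatters committing overlapping elements from least- to most-significant lane, and uses the immediate-blend suggestion:

#include <immintrin.h>
#include <stdint.h>

// Sketch only: correct when, for any partially conflicting {i, i+1} in
// vindex, i+1 sits in a higher lane than i (e.g. sorted indexes). Same
// out-of-bounds caveat as scatter2 above.
void scatter2_sorted(uint16_t *array, __m256i vindex, __m256i a) {
    __m256i was = _mm256_i32gather_epi32((int32_t*)array, vindex, 2);
    // New value into the low 16 bits of each element, old neighbour in the high 16.
    __m256i data = _mm256_blend_epi16(was, a, 0x55);
    // Overlapping scatter elements are committed in lane order, so the write
    // for i+1 lands after (and fixes up) the stale high half written for i.
    _mm256_i32scatter_epi32(array, vindex, data, 2);
}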
Upvotes: 1