user12722843

Searching for the key using SIMD

I have the following struct, which stores keys and generic user-specified values:

typedef struct {
        uint32_t  len;
        uint32_t  cap;
        int32_t  *keys;
        void     *vals;
} dict;

Now I want to create a function that iterates over the keys and returns the corresponding value.

The non-SIMD version:

void*
dict_find(dict *d, int32_t k, size_t s) {
        size_t i;
        i = 0;

        while (i < d->len) {
                if (d->keys[i] == k) {
                        void *p;
                        p = (uint8_t*)d->vals + i * s;

                        return p;
                }

                ++i;
        }

        return NULL;
}
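To illustrate what the `s` parameter means, here is a small usage sketch (not from the question). It assumes the values are `double`s stored contiguously in `vals`. In that case `s` is `sizeof(double)`, and the value for `keys[i]` starts `i * s` bytes into `vals`. The helper `dict_get_double` is hypothetical.

```c
#include <stddef.h>
#include <stdint.h>

typedef struct {
        uint32_t  len;
        uint32_t  cap;
        int32_t  *keys;
        void     *vals;
} dict;

/* Scalar lookup, equivalent to the question's dict_find: returns a pointer
   to the s-byte value stored at the same index as key k, or NULL. */
void*
dict_find(dict *d, int32_t k, size_t s) {
        for (size_t i = 0; i < d->len; ++i)
                if (d->keys[i] == k)
                        return (uint8_t*)d->vals + i * s;
        return NULL;
}

/* Hypothetical helper for a dict whose values are doubles: copies the value
   for key k into *out and returns 1, or returns 0 if k is absent. */
int
dict_get_double(dict *d, int32_t k, double *out) {
        const double *p = dict_find(d, k, sizeof *p);
        if (p == NULL)
                return 0;
        *out = *p;
        return 1;
}
```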

I tried to vectorize the snippet above and came up with this:

void*
dict_find_simd(dict *d, int32_t k, size_t s) {
        __m256i ymm0;
        ymm0 = _mm256_set1_epi32(k); // broadcast k to all 8 lanes

        __m256i  ymm1;
        uint32_t i;
        int      m;
        uint8_t  b;

        i = 0;
        while (i < d->len) { // assumes d->keys is 32-byte aligned and padded to a multiple of 32 bytes
                ymm1 = _mm256_load_si256((__m256i*)(d->keys + i));
                ymm1 = _mm256_cmpeq_epi32(ymm1, ymm0);

                m = _mm256_movemask_epi8(ymm1);
                b = __builtin_ctz(m) >> 2;

                i += (8 +  b * d->len); // Artificially break the loop. 
                                        // Remember [i] stores the modified value.
        }

        if (i <= d->len)
                return NULL;

        i -= (8 + b * d->len); // Restore the modified value.
        i += b;

        void *p;
        p = (uint8_t*)d->vals + i * s;

        return p;
}

The function seems to work correctly (though I haven't tested it much).

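However, there are two problems:

  1. I'm checking whether i > d->len and, if so, returning the pointer. But i can overflow, in which case the function returns NULL. How can I solve this issue?

  2. I'm using a combination of _mm256_movemask_epi8 and __builtin_ctz to get the index of the found key. Is there a better way to do this without AVX-512 (maybe a single instruction that gets the position of the non-zero value)?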


Answers (1)

Andrey Semashev

> I'm checking if i > d->len, and if so, I return the pointer. But i can overflow, and then the function returns NULL. How can I solve this issue?

There are two ways you could handle the overflow (and the potential out-of-bounds read that comes with it):

  1. Use the vector implementation only up to the largest i that is divisible by the vector size (in elements). If the vector loop didn't find the element, process the remaining tail elements in scalar code. This solution may be a good fit if the input data comes from elsewhere and there is no easy way to control memory allocation and initialization past the end of the buffer.

  2. Allow reads past the end of the buffer, and make sure that whatever is read there does not count as a valid (found) entry. Over-allocate the buffers so that you can always read a full vector's worth of data. This is easy to do: compare the resulting i against the number of elements in the container. If it is greater than or equal to that number, your algorithm "found" an element past the end, and you should report that nothing was found. In some cases, this can come naturally from the nature of your data. For example, you could fill the past-the-end elements with a key value that is never valid. Alternatively, your associated values may serve the same purpose (e.g. the past-the-end values are NULL pointers, which also indicate the "not found" result).
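Here is a minimal sketch of both strategies, not taken from the answer. It assumes AVX2 and a GCC/Clang-compatible compiler, because `__attribute__((target("avx2")))` and `__builtin_ctz` are compiler extensions. The names `dict_find_tail` and `dict_find_padded` are hypothetical. The sketch uses unaligned loads (`_mm256_loadu_si256`), so it does not depend on 32-byte alignment. The padded variant also assumes that `keys` has room for a whole vector past `len`.

```c
#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

typedef struct {
        uint32_t  len;
        uint32_t  cap;
        int32_t  *keys;
        void     *vals;
} dict;

/* Strategy 1: vector loop over whole 8-key blocks, scalar loop for the tail.
   Never reads past keys[len - 1]. */
__attribute__((target("avx2")))
void*
dict_find_tail(dict *d, int32_t k, size_t s) {
        const __m256i key = _mm256_set1_epi32(k);
        const size_t  n   = d->len;
        const size_t  nv  = n & ~(size_t)7;      /* largest multiple of 8 <= n */
        size_t i = 0;

        for (; i < nv; i += 8) {
                __m256i v = _mm256_loadu_si256((const __m256i*)(d->keys + i));
                int m = _mm256_movemask_epi8(_mm256_cmpeq_epi32(v, key));
                if (m != 0)                       /* test m before ctz */
                        return (uint8_t*)d->vals + (i + (__builtin_ctz(m) >> 2)) * s;
        }
        for (; i < n; ++i)                        /* scalar tail: at most 7 keys */
                if (d->keys[i] == k)
                        return (uint8_t*)d->vals + i * s;
        return NULL;
}

/* Strategy 2: assumes the keys buffer holds at least round_up(len, 8) ints.
   Keys past len may match k, so a hit at index >= len means "not found". */
__attribute__((target("avx2")))
void*
dict_find_padded(dict *d, int32_t k, size_t s) {
        const __m256i key = _mm256_set1_epi32(k);
        for (size_t i = 0; i < d->len; i += 8) {
                __m256i v = _mm256_loadu_si256((const __m256i*)(d->keys + i));
                int m = _mm256_movemask_epi8(_mm256_cmpeq_epi32(v, key));
                if (m != 0) {
                        size_t idx = i + (__builtin_ctz(m) >> 2);
                        return idx < d->len ? (uint8_t*)d->vals + idx * s : NULL;
                }
        }
        return NULL;
}
```

In the padded variant, the lowest set bit is always checked first. So a hit at or past `len` can only occur when there is no real match in the final block, and earlier blocks have already been ruled out.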

> You might have noticed that I'm using a combination of _mm256_movemask_epi8 and __builtin_ctz to get the index of the found key. Is there a better way to do this (maybe a single instruction that gets the position of a non-zero value) without AVX-512?

I don't think there is a single instruction for this, but you could improve the performance of this combination. Notice that you're comparing 32-bit values, which means _mm256_movemask_epi8 produces a 32-bit mask for only 8 elements (4 identical bits per element). You could improve data density by comparing 4 vectors of keys and packing the results so that each byte in a vector corresponds to a distinct comparison result. You would then apply a single _mm256_movemask_epi8.

ymm1 = _mm256_load_si256((__m256i*)(d->keys + i));
ymm2 = _mm256_load_si256((__m256i*)(d->keys + i) + 1);
ymm3 = _mm256_load_si256((__m256i*)(d->keys + i) + 2);
ymm4 = _mm256_load_si256((__m256i*)(d->keys + i) + 3);

ymm1 = _mm256_cmpeq_epi32(ymm1, ymm0);
ymm2 = _mm256_cmpeq_epi32(ymm2, ymm0);
ymm3 = _mm256_cmpeq_epi32(ymm3, ymm0);
ymm4 = _mm256_cmpeq_epi32(ymm4, ymm0);

ymm1 = _mm256_packs_epi32(ymm1, ymm2);
ymm3 = _mm256_packs_epi32(ymm3, ymm4);
ymm1 = _mm256_packs_epi16(ymm1, ymm3);
ymm1 = _mm256_permute4x64_epi64(ymm1, _MM_SHUFFLE(3, 1, 2, 0));
ymm1 = _mm256_shuffle_epi32(ymm1, _MM_SHUFFLE(3, 1, 2, 0));

m = _mm256_movemask_epi8(ymm1);
if (m)
{
    b = __builtin_ctz(m); // no shift needed here
    break;
}

(Note that the result of __builtin_ctz is undefined if m is zero. You could compensate for this after exiting the loop by checking whether i is within bounds. However, as shown above, I would rather test m before calling __builtin_ctz. That test both skips the __builtin_ctz call when nothing matched and serves as the signal to break out of the loop.)
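Put together, the fragment above might look like the following complete function. This is a sketch, not the answerer's exact code. It assumes AVX2 and a GCC/Clang-compatible compiler, and it uses unaligned loads. Each iteration processes 32 keys, and the remaining keys are handled with a scalar tail (strategy 1 above). `dict_find_packed` is a hypothetical name.

```c
#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

typedef struct {
        uint32_t  len;
        uint32_t  cap;
        int32_t  *keys;
        void     *vals;
} dict;

__attribute__((target("avx2")))
void*
dict_find_packed(dict *d, int32_t k, size_t s) {
        const __m256i ymm0 = _mm256_set1_epi32(k);
        const size_t  n    = d->len;
        size_t i = 0;

        for (; i + 32 <= n; i += 32) {
                const __m256i *p = (const __m256i*)(d->keys + i);
                __m256i ymm1 = _mm256_cmpeq_epi32(_mm256_loadu_si256(p + 0), ymm0);
                __m256i ymm2 = _mm256_cmpeq_epi32(_mm256_loadu_si256(p + 1), ymm0);
                __m256i ymm3 = _mm256_cmpeq_epi32(_mm256_loadu_si256(p + 2), ymm0);
                __m256i ymm4 = _mm256_cmpeq_epi32(_mm256_loadu_si256(p + 3), ymm0);

                /* 32-bit -> 16-bit -> 8-bit; signed saturation keeps -1 as 0xFF. */
                ymm1 = _mm256_packs_epi32(ymm1, ymm2);
                ymm3 = _mm256_packs_epi32(ymm3, ymm4);
                ymm1 = _mm256_packs_epi16(ymm1, ymm3);
                /* Undo the per-lane interleaving so byte j corresponds to key i + j. */
                ymm1 = _mm256_permute4x64_epi64(ymm1, _MM_SHUFFLE(3, 1, 2, 0));
                ymm1 = _mm256_shuffle_epi32(ymm1, _MM_SHUFFLE(3, 1, 2, 0));

                int m = _mm256_movemask_epi8(ymm1);
                if (m != 0)                        /* test m before ctz */
                        return (uint8_t*)d->vals + (i + __builtin_ctz(m)) * s;
        }
        for (; i < n; ++i)                         /* scalar tail */
                if (d->keys[i] == k)
                        return (uint8_t*)d->vals + i * s;
        return NULL;
}
```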

The problem with this is that packing is done per 128-bit lane, so you have to shuffle the data between lanes before you can use the result. These shuffles, and the packing itself, add overhead that may partly negate the benefits of this optimization. With 128-bit vectors you can avoid the shuffling, which may improve overall performance. I did not benchmark the code, so you will have to do the testing.
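With 128-bit vectors, a minimal SSE2 sketch of the same idea could look like this. It assumes x86-64, where SSE2 is always available, and GCC/Clang for `__builtin_ctz`. `dict_find_sse2` is a hypothetical name. Because there is only one 128-bit lane, the two levels of packing leave byte j of the result corresponding to key i + j, so no shuffle is needed. Each iteration covers 16 keys, and the movemask yields a 16-bit mask.

```c
#include <emmintrin.h>   /* SSE2 intrinsics */
#include <stddef.h>
#include <stdint.h>

typedef struct {
        uint32_t  len;
        uint32_t  cap;
        int32_t  *keys;
        void     *vals;
} dict;

void*
dict_find_sse2(dict *d, int32_t k, size_t s) {
        const __m128i key = _mm_set1_epi32(k);
        const size_t  n   = d->len;
        size_t i = 0;

        for (; i + 16 <= n; i += 16) {
                const __m128i *p = (const __m128i*)(d->keys + i);
                __m128i c0 = _mm_cmpeq_epi32(_mm_loadu_si128(p + 0), key);
                __m128i c1 = _mm_cmpeq_epi32(_mm_loadu_si128(p + 1), key);
                __m128i c2 = _mm_cmpeq_epi32(_mm_loadu_si128(p + 2), key);
                __m128i c3 = _mm_cmpeq_epi32(_mm_loadu_si128(p + 3), key);

                /* Single lane: packing preserves element order. */
                __m128i w0 = _mm_packs_epi32(c0, c1);   /* keys i+0 .. i+7  */
                __m128i w1 = _mm_packs_epi32(c2, c3);   /* keys i+8 .. i+15 */
                int m = _mm_movemask_epi8(_mm_packs_epi16(w0, w1));
                if (m != 0)
                        return (uint8_t*)d->vals + (i + __builtin_ctz(m)) * s;
        }
        for (; i < n; ++i)                               /* scalar tail */
                if (d->keys[i] == k)
                        return (uint8_t*)d->vals + i * s;
        return NULL;
}
```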

Another possible optimization is to skip the packing/shuffling and _mm256_movemask_epi8 when none of the comparisons is true. You could OR the comparison results together and use _mm256_testz_si256 to check whether the result is all zeros. You would then break out of the loop only when it is not.

ymm1 = _mm256_load_si256((__m256i*)(d->keys + i));
ymm2 = _mm256_load_si256((__m256i*)(d->keys + i) + 1);
ymm3 = _mm256_load_si256((__m256i*)(d->keys + i) + 2);
ymm4 = _mm256_load_si256((__m256i*)(d->keys + i) + 3);

ymm1 = _mm256_cmpeq_epi32(ymm1, ymm0);
ymm2 = _mm256_cmpeq_epi32(ymm2, ymm0);
ymm3 = _mm256_cmpeq_epi32(ymm3, ymm0);
ymm4 = _mm256_cmpeq_epi32(ymm4, ymm0);

ymm5 = _mm256_or_si256(ymm1, ymm2);
ymm6 = _mm256_or_si256(ymm3, ymm4);
ymm5 = _mm256_or_si256(ymm5, ymm6);

if (!_mm256_testz_si256(ymm5, ymm5))
{
    ymm1 = _mm256_packs_epi32(ymm1, ymm2);
    ymm3 = _mm256_packs_epi32(ymm3, ymm4);
    ymm1 = _mm256_packs_epi16(ymm1, ymm3);
    ymm1 = _mm256_permute4x64_epi64(ymm1, _MM_SHUFFLE(3, 1, 2, 0));
    ymm1 = _mm256_shuffle_epi32(ymm1, _MM_SHUFFLE(3, 1, 2, 0));

    m = _mm256_movemask_epi8(ymm1);
    b = __builtin_ctz(m);

    break;
}

Here, 3 OR operations are cheaper than 3 packs plus 2 shuffles, so you may save some cycles if your data is large enough (i.e. if, on average, you don't expect to find the key among the first elements). If the key is predominantly found among the first elements, this version will perform worse than the loop without _mm256_testz_si256.
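Here is a complete sketch of the _mm256_testz_si256 variant. It assumes AVX2 and a GCC/Clang-compatible compiler, uses unaligned loads, and finishes with a scalar tail. `dict_find_testz` is a hypothetical name. The ORs and the test run on every iteration. The packs, shuffles and movemask run only in the one iteration that contains a match, which is why m is guaranteed to be non-zero when `__builtin_ctz` is called.

```c
#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

typedef struct {
        uint32_t  len;
        uint32_t  cap;
        int32_t  *keys;
        void     *vals;
} dict;

__attribute__((target("avx2")))
void*
dict_find_testz(dict *d, int32_t k, size_t s) {
        const __m256i ymm0 = _mm256_set1_epi32(k);
        const size_t  n    = d->len;
        size_t i = 0;

        for (; i + 32 <= n; i += 32) {
                const __m256i *p = (const __m256i*)(d->keys + i);
                __m256i ymm1 = _mm256_cmpeq_epi32(_mm256_loadu_si256(p + 0), ymm0);
                __m256i ymm2 = _mm256_cmpeq_epi32(_mm256_loadu_si256(p + 1), ymm0);
                __m256i ymm3 = _mm256_cmpeq_epi32(_mm256_loadu_si256(p + 2), ymm0);
                __m256i ymm4 = _mm256_cmpeq_epi32(_mm256_loadu_si256(p + 3), ymm0);

                /* Three ORs + vptest: the "no match" path stops here. */
                __m256i ymm5 = _mm256_or_si256(_mm256_or_si256(ymm1, ymm2),
                                               _mm256_or_si256(ymm3, ymm4));
                if (_mm256_testz_si256(ymm5, ymm5))
                        continue;

                /* A match exists: pack to one byte per key and restore order. */
                ymm1 = _mm256_packs_epi32(ymm1, ymm2);
                ymm3 = _mm256_packs_epi32(ymm3, ymm4);
                ymm1 = _mm256_packs_epi16(ymm1, ymm3);
                ymm1 = _mm256_permute4x64_epi64(ymm1, _MM_SHUFFLE(3, 1, 2, 0));
                ymm1 = _mm256_shuffle_epi32(ymm1, _MM_SHUFFLE(3, 1, 2, 0));

                int m = _mm256_movemask_epi8(ymm1);      /* non-zero here */
                return (uint8_t*)d->vals + (i + __builtin_ctz(m)) * s;
        }
        for (; i < n; ++i)                               /* scalar tail */
                if (d->keys[i] == k)
                        return (uint8_t*)d->vals + i * s;
        return NULL;
}
```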


Here is an updated version of the above code based on suggestions by Peter Cordes in the comments.

ymm1 = _mm256_load_si256((__m256i*)(d->keys + i));
ymm2 = _mm256_load_si256((__m256i*)(d->keys + i) + 1);
ymm3 = _mm256_load_si256((__m256i*)(d->keys + i) + 2);
ymm4 = _mm256_load_si256((__m256i*)(d->keys + i) + 3);

ymm1 = _mm256_cmpeq_epi32(ymm1, ymm0);
ymm2 = _mm256_cmpeq_epi32(ymm2, ymm0);
ymm3 = _mm256_cmpeq_epi32(ymm3, ymm0);
ymm4 = _mm256_cmpeq_epi32(ymm4, ymm0);

ymm1 = _mm256_packs_epi32(ymm1, ymm2);
ymm3 = _mm256_packs_epi32(ymm3, ymm4);
ymm5 = _mm256_or_si256(ymm1, ymm3);  // cheap result to branch on 

if (_mm256_movemask_epi8(ymm5) != 0)
{
    ymm1 = _mm256_packs_epi16(ymm1, ymm3);     // now put the bits in order
    ymm1 = _mm256_permutevar8x32_epi32(ymm1,   // or vpermq + vpshufd like before
        _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));

    m = _mm256_movemask_epi8(ymm1);
    b = __builtin_ctz(m);

    break;
}

These improvements were made with Skylake or a similar microarchitecture in mind:

  1. Two of the packs are moved above the condition. They can execute efficiently because at most two vpcmpeqd execute per cycle, which is just enough to feed one vpackssdw per cycle (each pack consumes two comparison results). Two vpcmpeqd per cycle is achievable because two loads can be issued per cycle. In other words, the two pack instructions competing for port 5 do not become the bottleneck.

  2. The vpmovmskb instruction is a single µop with a latency of 2-3 cycles, whereas vptest is two µops (3 cycles). The subsequent scalar test of the mask fuses with jz/jnz, so the branch on _mm256_movemask_epi8 can execute slightly faster. Note that here _mm256_movemask_epi8 is applied to ymm5, a throwaway vector that is only used for the branch and not for computing the final index.

  3. The two shuffles in my earlier version can be replaced with a single vpermd (_mm256_permutevar8x32_epi32), which takes a vector constant. Here, I'm using _mm256_setr_epi32 to initialize the constant. Decent compilers will turn it into an in-memory constant with no extra instructions, but you may need to do this manually if your compiler is not smart enough. Also note that this constant is an additional memory access. That may come into play if your lookup tends to terminate early (i.e. if the code behind the condition contributes significantly to the total execution time of the algorithm). You could mitigate this by loading the constant before entering the loop. The algorithm doesn't use many vector registers, so you should have plenty to spare to keep the constant loaded.
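Here is a complete sketch of the updated version. It assumes AVX2 and a GCC/Clang-compatible compiler, uses unaligned loads, and finishes with a scalar tail. `dict_find_permd` is a hypothetical name. Following item 3, the vpermd control vector is created before the loop so the compiler can keep it in a register. Whether it actually stays there is up to the compiler.

```c
#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

typedef struct {
        uint32_t  len;
        uint32_t  cap;
        int32_t  *keys;
        void     *vals;
} dict;

__attribute__((target("avx2")))
void*
dict_find_permd(dict *d, int32_t k, size_t s) {
        const __m256i ymm0 = _mm256_set1_epi32(k);
        /* vpermd control that restores key order after the in-lane packs. */
        const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
        const size_t  n    = d->len;
        size_t i = 0;

        for (; i + 32 <= n; i += 32) {
                const __m256i *p = (const __m256i*)(d->keys + i);
                __m256i ymm1 = _mm256_cmpeq_epi32(_mm256_loadu_si256(p + 0), ymm0);
                __m256i ymm2 = _mm256_cmpeq_epi32(_mm256_loadu_si256(p + 1), ymm0);
                __m256i ymm3 = _mm256_cmpeq_epi32(_mm256_loadu_si256(p + 2), ymm0);
                __m256i ymm4 = _mm256_cmpeq_epi32(_mm256_loadu_si256(p + 3), ymm0);

                /* Item 1: two packs above the branch. */
                ymm1 = _mm256_packs_epi32(ymm1, ymm2);
                ymm3 = _mm256_packs_epi32(ymm3, ymm4);
                __m256i ymm5 = _mm256_or_si256(ymm1, ymm3);   /* cheap result to branch on */

                /* Item 2: branch on vpmovmskb instead of vptest. */
                if (_mm256_movemask_epi8(ymm5) != 0) {
                        ymm1 = _mm256_packs_epi16(ymm1, ymm3);
                        ymm1 = _mm256_permutevar8x32_epi32(ymm1, perm);  /* item 3 */
                        int m = _mm256_movemask_epi8(ymm1);
                        return (uint8_t*)d->vals + (i + __builtin_ctz(m)) * s;
                }
        }
        for (; i < n; ++i)                                     /* scalar tail */
                if (d->keys[i] == k)
                        return (uint8_t*)d->vals + i * s;
        return NULL;
}
```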
