Daniel

Reputation: 8441

Emulating shifts on 64 bytes with AVX-512

My question is an extension of a previous question: Emulating shifts on 32 bytes with AVX.

How do I implement similar shifts on 64 bytes with AVX-512? Specifically, how should I implement _mm512_slli_si512 and _mm512_srli_si512, corresponding to the SSE2 intrinsics _mm_slli_si128 and _mm_srli_si128?

Upvotes: 3

Views: 1362

Answers (2)

Alexis Wilke

Reputation: 20730

If you need to shift by exactly 64 bits, you can use a permute instruction, which works entirely in registers. For a shift by a multiple of 8 bits you could use a byte shuffle (see VPSHUFB; if you are dealing with floats, look at the cast functions, since the shuffles work on integer vectors).

Here is an example that shifts right by 64 bits ("SHR zmm1, 64"). The mask is used to clear the top 64 bits; if you want ROR-like functionality, use the version without the mask. A shift to the left is also possible: just change the indexes and the mask as required (see the left-shift sketch at the end of this answer).

#include <immintrin.h>
#include <cstdint>
#include <iostream>

void show(char const * msg, double const * v)
{
    std::cout << msg << ":";
    for(int j(0); j < 8; ++j)
    {
        std::cout << " " << v[j];
    }
    std::cout << "\n";
}


int main()
{
    double v[8] = { 1., 2., 3., 4., 5., 6., 7., 8. };
    double q[8] = {};
    alignas(64) std::uint64_t indexes[8] = { 1, 2, 3, 4, 5, 6, 7, 0 };

    show("init", v);
    show("q", q);

    // load
    __m512d a(_mm512_loadu_pd(v));
    __m512i i(_mm512_load_epi64(indexes));

    // shift
    //__m512d b(_mm512_permutex_pd(a, 0x39));   // can't cross between 4 low and 4 high with an immediate
    //__m512d b(_mm512_permutexvar_pd(i, a));   // ROR
    __m512d b(_mm512_maskz_permutexvar_pd(0x7F, i, a));   // LSR on a double basis

    // store
    _mm512_storeu_pd(q, b);

    show("shifted", q);
    show("original", v);
}

Fully optimized output (-O3) reduces the whole shift to 3 instructions (which are intermingled with others in the output):

 96a:   62 f1 fd 48 6f 85 10    vmovdqa64 -0xf0(%rbp),%zmm0
 971:   ff ff ff 
 974:   b8 7f 00 00 00          mov    $0x7f,%eax              # mask
 979:   48 8d 3d 10 04 00 00    lea    0x410(%rip),%rdi        # d90 <_IO_stdin_used+0x10>
 980:   c5 f9 92 c8             kmovb  %eax,%k1                # special k1 register
 984:   4c 89 e6                mov    %r12,%rsi
 987:   62 f2 fd c9 16 85 d0    vpermpd -0x130(%rbp),%zmm0,%zmm0{%k1}{z}   # "shift"
 98e:   fe ff ff 
 991:   62 f1 fd 48 11 45 fe    vmovupd %zmm0,-0x80(%rbp)

In my case I want to use this in a loop, and the load (vmovdqa64) and store (vmovupd) will be hoisted before and after the loop, so inside the loop it will be really fast. (It needs to rotate that way 4,400 times before I need to save the result.)

As pointed out by Peter, we can also use the valignq instruction:

// this replaces the permute and does not need the index vector;
// with mask 0xFF this is a rotate (ROR by 64 bits) -- pass 0x7F instead
// to zero the top element and get the shift, as with the masked permute above
__m512i b(_mm512_maskz_alignr_epi64(0xFF, _mm512_castpd_si512(a), _mm512_castpd_si512(a), 1));

and the result is one instruction like so:

 979:   62 f1 fd 48 6f 85 d0    vmovdqa64 -0x130(%rbp),%zmm0
 980:   fe ff ff 
 983:   48 8d 75 80             lea    -0x80(%rbp),%rsi
 987:   48 8d 3d 02 04 00 00    lea    0x402(%rip),%rdi        # d90 <_IO_stdin_used+0x10>
 98e:   62 f3 fd 48 03 c0 01    valignq $0x1,%zmm0,%zmm0,%zmm0
 995:   62 f1 fd 48 11 45 fd    vmovupd %zmm0,-0xc0(%rbp)

An important point: using fewer registers is also much better, since it increases our chances of keeping the whole computation in registers instead of having to go through memory (512 bits is a lot to transfer to and from memory).
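
For reference, here is a minimal sketch of the left-shift variant mentioned above; only the index vector and the mask change (the function name is just illustrative):

#include <immintrin.h>
#include <cstdint>

// shift a __m512d left by one 64-bit element ("SHL zmm, 64"),
// i.e. the result is { 0, a[0], a[1], ..., a[6] }
__m512d shift_left_one_double(__m512d a)
{
    // index 7 in slot 0 is irrelevant since the mask bit for slot 0 is clear
    alignas(64) static std::uint64_t const left_indexes[8] = { 7, 0, 1, 2, 3, 4, 5, 6 };
    __m512i il(_mm512_load_epi64(left_indexes));
    return _mm512_maskz_permutexvar_pd(0xFE, il, a);   // mask 0xFE zeroes element 0
}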

Upvotes: 1

chtz

Reputation: 18807

Here is a working solution using a temporary array:

#include <immintrin.h>
#include <cstddef>

__m512i _mm512_srli_si512(__m512i a, size_t imm8)
{
    // set up temporary array and set upper half to zero 
    // (this needs to happen outside any critical loop)
    alignas(64) char temp[128];
    _mm512_store_si512(temp+64, _mm512_setzero_si512());

    // store input into lower half
    _mm512_store_si512(temp, a);

    // load shifted register
    return _mm512_loadu_si512(temp+imm8);
}

__m512i _mm512_slli_si512(__m512i a, size_t imm8)
{
    // set up temporary array and set lower half to zero 
    // (this needs to happen outside any critical loop)
    alignas(64) char temp[128];
    _mm512_store_si512(temp, _mm512_setzero_si512());

    // store input into upper half
    _mm512_store_si512(temp+64, a);

    // load shifted register
    return _mm512_loadu_si512(temp+(64-imm8));
}

This should also work if imm8 is not known at compile time, but it does not do any out-of-bounds checks. You could actually use a single 3*64-byte temporary and share it between the left- and right-shift methods (and both would then also work for negative shift counts), as sketched below.

Of course, if you share a temporary outside the function body, you must make sure that it is not accessed by multiple threads at once.
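
A minimal sketch of that shared-temporary idea (the name shift_bytes and the static buffer are my own choices; per the caveat above, such a buffer must not be shared between threads):

#include <immintrin.h>
#include <cstddef>

// one shared 3*64-byte buffer: [ 64 zero bytes | input | 64 zero bytes ]
alignas(64) static char shift_temp[3 * 64];

// positive count shifts right, negative count shifts left, |count| <= 64
__m512i shift_bytes(__m512i a, std::ptrdiff_t count)
{
    // zero both neighbours of the middle slot
    // (this could be hoisted outside any critical loop)
    _mm512_store_si512(shift_temp, _mm512_setzero_si512());
    _mm512_store_si512(shift_temp + 128, _mm512_setzero_si512());

    // store the input into the middle 64 bytes
    _mm512_store_si512(shift_temp + 64, a);

    // a right shift by count bytes starts reading count bytes further on;
    // a negative count starts reading before the data, i.e. shifts left
    return _mm512_loadu_si512(shift_temp + 64 + count);
}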

Godbolt-Link with usage demonstration: https://godbolt.org/z/LSgeWZ


As Peter noted, this store-load trick will cause a store-forwarding stall on all CPUs with AVX-512. The most efficient forwarding case (~6 cycle latency) only works when all the loaded bytes come from a single store. If the load goes outside the most recent store that overlaps it at all, it has extra latency (around ~16 cycles) to scan the store buffer and, if needed, merge in bytes from L1d cache. See Can modern x86 implementations store-forward from more than one prior store? and Agner Fog's microarch guide for more details. This extra scanning can probably happen for multiple loads in parallel, and at least doesn't stall other things (like normal store-forwarding or the rest of the pipeline), so it may not be a throughput problem.

If you want many shift offsets of the same data, one store and multiple reloads at different alignments should be good.

But if latency is your primary concern you should try a solution based on valignd (also, if you only need to shift by a multiple of 4 bytes, that is obviously the easier problem anyway). Or, for constant shift counts, a vector control for vpermw could work; a sketch follows.
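
A hedged sketch of that vpermw idea, for a compile-time shift right by a multiple of 2 bytes (requires AVX-512BW; the function name and the runtime construction of the control vector are my own -- in practice the control would be a precomputed constant):

#include <immintrin.h>
#include <cstdint>

// right-shift the whole 64-byte register by N bytes, N a constant multiple of 2
template<int N>
__m512i srli_bytes_vpermw(__m512i a)
{
    static_assert(N % 2 == 0 && N >= 0 && N <= 64, "N must be a multiple of 2 bytes");
    constexpr int W = N / 2;                        // shift amount in 16-bit words

    // word j of the result should come from word j + W of the input
    alignas(64) std::uint16_t idx[32];
    for (int j = 0; j < 32; ++j)
        idx[j] = static_cast<std::uint16_t>(j + W);
    __m512i control = _mm512_load_si512(idx);

    // words whose source would fall past the end are zeroed via the mask
    __mmask32 keep = (W >= 32) ? 0 : (0xFFFFFFFFu >> W);
    return _mm512_maskz_permutexvar_epi16(keep, control, a);
}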


For completeness, here is a version based on valignd and valignr that works for shifts from 0 to 64 bytes known at compile time (using C++17 -- but you can easily avoid the if constexpr; it is only here because of the static_assert). Instead of shifting in zeros you can pass a second register (i.e., it behaves the way valignr would behave if it aligned across lanes).

template<int N>
__m512i shift_right(__m512i a, __m512i carry = _mm512_setzero_si512())
{
  static_assert(0 <= N && N <= 64);
  if constexpr(N   == 0) return a;
  if constexpr(N   ==64) return carry;
  if constexpr(N%4 == 0) return _mm512_alignr_epi32(carry, a, N / 4);
  else
  {
    __m512i a0 = shift_right< (N/16 + 1)*16>(a, carry);  // 16, 32, 48, 64
    __m512i a1 = shift_right< (N/16    )*16>(a, carry);  //  0, 16, 32, 48
    return _mm512_alignr_epi8(a0, a1, N % 16);
  }
}

template<int N>
__m512i shift_left(__m512i a, __m512i carry = _mm512_setzero_si512())
{
  return shift_right<64-N>(carry, a);
}

Here is a godbolt-link with some example assembly as well as output for every possible shift_right operation: https://godbolt.org/z/xmKJvA

GCC faithfully translates this into valignd and valignr instructions -- but may emit an unnecessary vpxor instruction (e.g. in the shiftleft_49 example). Clang does some crazy substitutions (not sure whether they actually make a difference, though).

The code could be extended to shift an arbitrary sequence of registers (always carrying in bytes from the previous register). A short usage sketch follows.
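
A small usage sketch of the templates above (assuming they are in scope; the byte values and the printed elements are just for illustration):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main()
{
    alignas(64) std::uint8_t bytes[64];
    for (int i = 0; i < 64; ++i)
        bytes[i] = static_cast<std::uint8_t>(i);

    __m512i a  = _mm512_load_si512(bytes);
    __m512i r1 = shift_right<7>(a);      // bytes 7..63 of a, then 7 zero bytes
    __m512i r2 = shift_right<7>(a, a);   // carry = a, so this is a rotate by 7 bytes

    alignas(64) std::uint8_t out[64];
    _mm512_store_si512(out, r1);
    std::printf("%u %u ... %u\n", out[0], out[1], out[63]);   // prints: 7 8 ... 0

    _mm512_store_si512(out, r2);
    std::printf("%u %u ... %u\n", out[0], out[1], out[63]);   // prints: 7 8 ... 6
}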

Upvotes: 2
