user2454034
user2454034

Reputation: 101

How to optimize SIMD transpose function (8x4 => 4x8)?

I need to optimize the transpose of 8x4 and 4x8 float matrices with AVX. I use Agner Fog's vector class library.

The real task is to build a BVH and sum min-max values. The transpose is used in the final stage of every loop iteration (the loops are also parallelized across threads, but the number of tasks can be very large).

Code now looks like:

void transpose(register Vec4f (&fin)[8], register Vec8f (&mat)[4]) {
    for (int i = 0;i < 8;i++) {
        fin[i] = lookup<28>(Vec4i(0, 8, 16, 24) + i, (float *)mat);
    }
}

I am looking for ways to optimize this. How can this function be optimized for SIMD?


I recently wrote my own variants of the transpose (4x8 and 8x4) with vectorclass. This is version 1.0:

void transpose(register Vec4f(&fin)[8], register Vec8f(&mat)[4]) {
    register Vec8f a00 = blend8f<0, 8, 1, 9, 2, 10, 3, 11>(mat[0], mat[1]);
    register Vec8f a10 = blend8f<0, 8, 1, 9, 2, 10, 3, 11>(mat[2], mat[3]);
    register Vec8f a01 = blend8f<4, 12, 5, 13, 6, 14, 7, 15>(mat[0], mat[1]);
    register Vec8f a11 = blend8f<4, 12, 5, 13, 6, 14, 7, 15>(mat[2], mat[3]);

    register Vec8f v0_1 = blend8f<0, 1, 8, 9, 2, 3, 10, 11>(a00, a10);
    register Vec8f v2_3 = blend8f<4, 5, 12, 13, 6, 7, 14, 15>(a00, a10);
    register Vec8f v4_5 = blend8f<0, 1, 8, 9, 2, 3, 10, 11>(a01, a11);
    register Vec8f v6_7 = blend8f<4, 5, 12, 13, 6, 7, 14, 15>(a01, a11);

    fin[0] = v0_1.get_low();
    fin[1] = v0_1.get_high();
    fin[2] = v2_3.get_low();
    fin[3] = v2_3.get_high();
    fin[4] = v4_5.get_low();
    fin[5] = v4_5.get_high();
    fin[6] = v6_7.get_low();
    fin[7] = v6_7.get_high();
}

void transpose(register Vec8f(&fin)[4], register Vec4f(&mat)[8]) {
    register Vec8f a0_1 = Vec8f(mat[0], mat[1]);
    register Vec8f a2_3 = Vec8f(mat[2], mat[3]);
    register Vec8f a4_5 = Vec8f(mat[4], mat[5]);
    register Vec8f a6_7 = Vec8f(mat[6], mat[7]);

    register Vec8f a00 = blend8f<0, 4, 8 , 12, 1, 5, 9 , 13>(a0_1, a2_3);
    register Vec8f a10 = blend8f<0, 4, 8 , 12, 1, 5, 9 , 13>(a4_5, a6_7);
    register Vec8f a01 = blend8f<2, 6, 10, 14, 3, 7, 11, 15>(a0_1, a2_3);
    register Vec8f a11 = blend8f<2, 6, 10, 14, 3, 7, 11, 15>(a4_5, a6_7);

    fin[0] = blend8f<0, 1, 2, 3, 8, 9, 10, 11>(a00, a10);
    fin[1] = blend8f<4, 5, 6, 7, 12, 13, 14, 15>(a00, a10);
    fin[2] = blend8f<0, 1, 2, 3, 8, 9, 10, 11>(a01, a11);
    fin[3] = blend8f<4, 5, 6, 7, 12, 13, 14, 15>(a01, a11);
}

I need a version 2.0.

Upvotes: 4

Views: 1370

Answers (2)

Z boson
Z boson

Reputation: 33679

The Vector Class Library (VCL) uses template meta-programming to determine the best intrinsics for permuting and blending. However, when it comes to permuting and blending you often still need to know the limitations of the hardware to get the best results.

I converted Stgatilov's already excellent answer to use the VCL and it produces ideal assembly (eight shuffles). Here is the function:

void tran8x4_AVX(float *a, float *b) {
    Vec8f tmp0, tmp1, tmp2, tmp3;
    Vec8f row0, row1, row2, row3;

    row0 = Vec8f().load(&a[8*0]);
    row1 = Vec8f().load(&a[8*1]);
    row2 = Vec8f().load(&a[8*2]);
    row3 = Vec8f().load(&a[8*3]);    

    tmp0 = blend8f<0, 1,  8, 9,  4, 5, 12, 13>(row0, row1);
    tmp2 = blend8f<2, 3, 10, 11, 6, 7, 14, 15>(row0, row1);
    tmp1 = blend8f<0, 1,  8, 9,  4, 5, 12, 13>(row2, row3);
    tmp3 = blend8f<2, 3, 10, 11, 6, 7, 14, 15>(row2, row3);

    row0 = blend8f<0, 2, 8, 10, 4, 6, 12, 14>(tmp0, tmp1);
    row1 = blend8f<1, 3, 9, 11, 5, 7, 13, 15>(tmp0, tmp1);
    row2 = blend8f<0, 2, 8, 10, 4, 6, 12, 14>(tmp2, tmp3);
    row3 = blend8f<1, 3, 9, 11, 5, 7, 13, 15>(tmp2, tmp3);

    row0.get_low().store(&b[  4*0]);
    row1.get_low().store(&b[  4*1]);
    row2.get_low().store(&b[  4*2]);
    row3.get_low().store(&b[  4*3]);
    row0.get_high().store(&b[ 4*4]);
    row1.get_high().store(&b[ 4*5]);
    row2.get_high().store(&b[ 4*6]);
    row3.get_high().store(&b[ 4*7]);
}

Here is the assembly (g++ -S -O3 -mavx test.cpp)

    vmovups 32(%rdi), %ymm4
    vmovups 64(%rdi), %ymm3
    vmovups (%rdi), %ymm1
    vmovups 96(%rdi), %ymm0
    vshufps $68, %ymm4, %ymm1, %ymm2
    vshufps $68, %ymm0, %ymm3, %ymm5
    vshufps $238, %ymm4, %ymm1, %ymm1
    vshufps $238, %ymm0, %ymm3, %ymm0
    vshufps $136, %ymm5, %ymm2, %ymm4
    vshufps $221, %ymm5, %ymm2, %ymm2
    vshufps $136, %ymm0, %ymm1, %ymm3
    vshufps $221, %ymm0, %ymm1, %ymm0
    vmovups %xmm4, (%rsi)
    vextractf128    $0x1, %ymm4, %xmm4
    vmovups %xmm2, 16(%rsi)
    vextractf128    $0x1, %ymm2, %xmm2
    vmovups %xmm3, 32(%rsi)
    vextractf128    $0x1, %ymm3, %xmm3
    vmovups %xmm0, 48(%rsi)
    vextractf128    $0x1, %ymm0, %xmm0
    vmovups %xmm4, 64(%rsi)
    vmovups %xmm2, 80(%rsi)
    vmovups %xmm3, 96(%rsi)
    vmovups %xmm0, 112(%rsi)
    vzeroupper
    ret
    .cfi_endproc

Here is a full test

#include <stdio.h>
#include "vectorclass.h"

// Scalar reference transpose of a row-major 4x8 float matrix.
//
// a: input, 32 floats laid out as four rows of eight.
// b: output, 32 floats laid out as eight rows of four; b[j*4 + i] = a[i*8 + j].
void tran8x4(float *a, float *b) {
    // Walk the input linearly; each element's row/column are swapped on output.
    for (int k = 0; k < 32; ++k) {
        const int row = k / 8;
        const int col = k % 8;
        b[col * 4 + row] = a[k];
    }
}

void tran8x4_AVX(float *a, float *b) {
    Vec8f tmp0, tmp1, tmp2, tmp3;
    Vec8f row0, row1, row2, row3;

    row0 = Vec8f().load(&a[8*0]);
    row1 = Vec8f().load(&a[8*1]);
    row2 = Vec8f().load(&a[8*2]);
    row3 = Vec8f().load(&a[8*3]);


    tmp0 = blend8f<0, 1, 8, 9, 4, 5, 12, 13>(row0, row1);
    tmp2 = blend8f<2, 3, 10, 11, 6, 7, 14, 15>(row0, row1);
    tmp1 = blend8f<0, 1, 8, 9, 4, 5, 12, 13>(row2, row3);
    tmp3 = blend8f<2, 3, 10, 11, 6, 7, 14, 15>(row2, row3);

    row0 = blend8f<0, 2, 8, 10, 4, 6, 12, 14>(tmp0, tmp1);
    row1 = blend8f<1, 3, 9, 11, 5, 7, 13, 15>(tmp0, tmp1);
    row2 = blend8f<0, 2, 8, 10, 4, 6, 12, 14>(tmp2, tmp3);
    row3 = blend8f<1, 3, 9, 11, 5, 7, 13, 15>(tmp2, tmp3);

    row0.get_low().store(&b[  4*0]);
    row1.get_low().store(&b[  4*1]);
    row2.get_low().store(&b[  4*2]);
    row3.get_low().store(&b[  4*3]);
    row0.get_high().store(&b[ 4*4]);
    row1.get_high().store(&b[ 4*5]);
    row2.get_high().store(&b[ 4*6]);
    row3.get_high().store(&b[ 4*7]);

}


// Test driver: fills a 4x8 matrix with 0..31, transposes it with the scalar
// reference and with the AVX routine, and prints the input followed by both
// 8x4 results (each separated by a blank line) for visual comparison.
int main() {
    float src[32];
    float ref_out[32];
    float avx_out[32];

    for (int k = 0; k < 32; ++k) {
        src[k] = static_cast<float>(k);
    }

    // Input: 4 rows of 8 values; break the line after every 8th value.
    for (int k = 0; k < 32; ++k) {
        printf("%2.0f ", src[k]);
        if (k % 8 == 7) {
            puts("");
        }
    }

    tran8x4(src, ref_out);
    tran8x4_AVX(src, avx_out);

    // Scalar result: 8 rows of 4 values.
    puts("");
    for (int k = 0; k < 32; ++k) {
        printf("%2.0f ", ref_out[k]);
        if (k % 4 == 3) {
            puts("");
        }
    }

    // AVX result: 8 rows of 4 values.
    puts("");
    for (int k = 0; k < 32; ++k) {
        printf("%2.0f ", avx_out[k]);
        if (k % 4 == 3) {
            puts("");
        }
    }
}

Upvotes: 1

stgatilov
stgatilov

Reputation: 5533

I have no experience with the vectorclass library, but from briefly looking through the sources of the lookup template function, it seems that you are doing something terribly inefficient.

I propose a simple and efficient solution with SSE/AVX intrinsics below. I have no idea how to express it fully in terms of the vectorclass library. However, you can use the conversion operators to extract the raw data as __m128 and __m256 from the classes Vec4f and Vec8f, and the corresponding constructors let you convert the raw results back into vector classes.


In pure SSE with intrinsics, there is a macro _MM_TRANSPOSE4_PS in the header xmmintrin.h. It transposes 4x4 matrix of floats with each row in a separate 128-bit register. If you have only SSE (i.e. no AVX), then you should just call this macro twice and you are done. Here is the code:

/*
 * In-place transpose of a 4x4 float matrix whose rows are held in four
 * __m128 variables.
 *
 * row0..row3 are both inputs and outputs: each is read by the first four
 * shuffles and overwritten by the last four, so every argument must be a
 * modifiable __m128 lvalue and is referenced more than once.
 *
 * Shuffle immediates (two elements taken from each operand):
 *   0x44 -> elements 0,1   0xEE -> elements 2,3
 *   0x88 -> elements 0,2   0xDD -> elements 1,3
 *
 * The macro expands to a braced block that declares tmp0..tmp3, so arguments
 * with those names would be shadowed inside the expansion.
 */
#define _MM_TRANSPOSE4_PS(row0, row1, row2, row3) {    \
  __m128 tmp3, tmp2, tmp1, tmp0;                      \
  tmp0 = _mm_shuffle_ps(row0, row1, 0x44);            \
  tmp2 = _mm_shuffle_ps(row0, row1, 0xEE);            \
  tmp1 = _mm_shuffle_ps(row2, row3, 0x44);            \
  tmp3 = _mm_shuffle_ps(row2, row3, 0xEE);            \
  row0 = _mm_shuffle_ps(tmp0, tmp1, 0x88);            \
  row1 = _mm_shuffle_ps(tmp0, tmp1, 0xDD);            \
  row2 = _mm_shuffle_ps(tmp2, tmp3, 0x88);            \
  row3 = _mm_shuffle_ps(tmp2, tmp3, 0xDD);            \
}

In AVX, an instruction with 256-bit operands usually just performs the SSE-equivalent operation on the two halves of the operands (called lanes). The intrinsic _mm256_shuffle_ps is no exception: it simply shuffles each of the two 128-bit lanes as its _mm equivalent does. This means that if we change the _mm prefix to the _mm256 prefix in the macro, it transposes two 4x4 matrices: the one located in the lower lanes of the four 256-bit registers, and the one located in the upper lanes of the four 256-bit registers. We only have to break the resulting 256-bit registers into halves and order them properly.

The resulting code is presented below. I have checked that it works properly. It seems to have only 12 instructions, so I guess it would be quite fast.

// Transposes a 4x8 float matrix (four __m256 rows) into eight __m128 rows.
//
// dst: output; dst[c] receives { src[0][c], src[1][c], src[2][c], src[3][c] }.
// src: the four 8-float input rows.
//
// The shuffles transpose the 4x4 block in the lower lanes and the 4x4 block in
// the upper lanes at the same time; the lanes are then split apart.
void Transpose4x8(__m128 dst[8], __m256 src[4]) {
  const __m256 in0 = src[0];
  const __m256 in1 = src[1];
  const __m256 in2 = src[2];
  const __m256 in3 = src[3];

  // Per lane: 0x44 takes elements 0,1 of each operand, 0xEE takes elements 2,3.
  const __m256 lo01 = _mm256_shuffle_ps(in0, in1, 0x44);
  const __m256 hi01 = _mm256_shuffle_ps(in0, in1, 0xEE);
  const __m256 lo23 = _mm256_shuffle_ps(in2, in3, 0x44);
  const __m256 hi23 = _mm256_shuffle_ps(in2, in3, 0xEE);

  // Per lane: 0x88 takes elements 0,2 of each operand, 0xDD takes elements 1,3.
  // colsXY = { column X in the lower lane | column Y in the upper lane }.
  const __m256 cols04 = _mm256_shuffle_ps(lo01, lo23, 0x88);
  const __m256 cols15 = _mm256_shuffle_ps(lo01, lo23, 0xDD);
  const __m256 cols26 = _mm256_shuffle_ps(hi01, hi23, 0x88);
  const __m256 cols37 = _mm256_shuffle_ps(hi01, hi23, 0xDD);

  // Lower lanes hold columns 0..3.
  dst[0] = _mm256_castps256_ps128(cols04);
  dst[1] = _mm256_castps256_ps128(cols15);
  dst[2] = _mm256_castps256_ps128(cols26);
  dst[3] = _mm256_castps256_ps128(cols37);

  // Upper lanes hold columns 4..7.
  dst[4] = _mm256_extractf128_ps(cols04, 1);
  dst[5] = _mm256_extractf128_ps(cols15, 1);
  dst[6] = _mm256_extractf128_ps(cols26, 1);
  dst[7] = _mm256_extractf128_ps(cols37, 1);
}

UPDATE: The inverse transposition is done in much the same way, with some steps in reverse order. Here is the code:

// Transposes an 8x4 float matrix (eight __m128 rows) into four __m256 rows.
//
// dst: output; dst[c] receives { src[0][c], src[1][c], ..., src[7][c] }.
// src: the eight 4-float input rows.
void Transpose8x4(__m256 dst[4], __m128 src[8]) {
  // Put row k in the lower lane and row k+4 in the upper lane, so a per-lane
  // 4x4 transpose directly yields the 8-wide output rows.
  const __m256 in04 = _mm256_setr_m128(src[0], src[4]);
  const __m256 in15 = _mm256_setr_m128(src[1], src[5]);
  const __m256 in26 = _mm256_setr_m128(src[2], src[6]);
  const __m256 in37 = _mm256_setr_m128(src[3], src[7]);

  // Per lane: 0x44 takes elements 0,1 of each operand, 0xEE takes elements 2,3.
  const __m256 lo_a = _mm256_shuffle_ps(in04, in15, 0x44);
  const __m256 hi_a = _mm256_shuffle_ps(in04, in15, 0xEE);
  const __m256 lo_b = _mm256_shuffle_ps(in26, in37, 0x44);
  const __m256 hi_b = _mm256_shuffle_ps(in26, in37, 0xEE);

  // Per lane: 0x88 takes elements 0,2 of each operand, 0xDD takes elements 1,3.
  dst[0] = _mm256_shuffle_ps(lo_a, lo_b, 0x88);
  dst[1] = _mm256_shuffle_ps(lo_a, lo_b, 0xDD);
  dst[2] = _mm256_shuffle_ps(hi_a, hi_b, 0x88);
  dst[3] = _mm256_shuffle_ps(hi_a, hi_b, 0xDD);
}

Upvotes: 4

Related Questions