bickit

Reputation: 51

Transposing 8x8 float matrix using NEON intrinsics

I have a program that needs to run a transpose operation on 8x8 float32 matrices many times. I want to transpose these using NEON SIMD intrinsics. I know that the array will always contain 8x8 float elements. I have a baseline non-intrinsic solution below:

void transpose(float *matrix, float *matrixT) {
    for (int i = 0; i < 8; i++) {
        for (int j = 0; j < 8; j++) {
            matrixT[i*8+j] = matrix[j*8+i];
        }
    }
}

I also created an intrinsic solution that transposes each 4x4 quadrant of the 8x8 matrix, and swaps the positions of the second and third quadrants. This solution looks like this:

void transpose_4x4(float *matrix, float *matrixT, int store_index) {
    float32x4_t r0, r1, r2, r3, c0, c1, c2, c3;
    r0 = vld1q_f32(matrix);
    r1 = vld1q_f32(matrix + 8);
    r2 = vld1q_f32(matrix + 16);
    r3 = vld1q_f32(matrix + 24);

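    // Interleave rows pairwise: c0 = {r0[0], r1[0], r0[1], r1[1]}, c1 = {r0[2], r1[2], r0[3], r1[3]}, etc.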
    c0 = vzip1q_f32(r0, r1);
    c1 = vzip2q_f32(r0, r1);
    c2 = vzip1q_f32(r2, r3);
    c3 = vzip2q_f32(r2, r3);

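    // Combine the 64-bit halves of the interleaved vectors to form the four transposed rows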
    r0 = vcombine_f32(vget_low_f32(c0), vget_low_f32(c2));
    r1 = vcombine_f32(vget_high_f32(c0), vget_high_f32(c2));
    r2 = vcombine_f32(vget_low_f32(c1), vget_low_f32(c3));
    r3 = vcombine_f32(vget_high_f32(c1), vget_high_f32(c3));

    vst1q_f32(matrixT + store_index, r0);
    vst1q_f32(matrixT + store_index + 8, r1);
    vst1q_f32(matrixT + store_index + 16, r2);
    vst1q_f32(matrixT + store_index + 24, r3);
}

void transpose(float *matrix, float *matrixT) {
    // Transpose top-left 4x4 quadrant and store the result in the top-left 4x4 quadrant
    transpose_4x4(matrix, matrixT, 0);

    // Transpose top-right 4x4 quadrant and store the result in the bottom-left 4x4 quadrant
    transpose_4x4(matrix + 4, matrixT, 32);

    // Transpose bottom-left 4x4 quadrant and store the result in the top-right 4x4 quadrant
    transpose_4x4(matrix + 32, matrixT, 4);

    // Transpose bottom-right 4x4 quadrant and store the result in the bottom-right 4x4 quadrant
    transpose_4x4(matrix + 36, matrixT, 36);
}

This solution, however, results in slower performance than the baseline non-intrinsic solution. I am struggling to see whether there is a faster way to transpose my 8x8 matrix. Any help would be greatly appreciated!

Edit: both solutions are compiled using the -O1 flag.

Upvotes: 4

Views: 1120

Answers (2)

Aki Suihkonen

Reputation: 20037

It's possible to optimise the 8x8 NEON code presented in the other answer; the 8x8 transpose can not only be thought of as a recursive version of [A B; C D]' == [A' C'; B' D'], but also as a repeated application of zip or unzip.

  a b c d
  e f g h
  i j k l
  m n o p  == a b c d e f g h i j k l m n o p

  zip(first half, last half) == a i b j c k d l e m f n g o h p
  zip again                  == a e i m b f j n c g k o d h l p == transpose
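
As a plain-C illustration of that claim, one zip pass over the flattened array can be written as below (a sketch for illustration only; zip_pass and transpose8x8_zip are just made-up names). Applying it log2(N) times - three times for 8x8 - yields the transpose:

#include <string.h>

/* One zip pass over the flattened 8x8 array: interleave the first and
   second halves. Three such passes produce the transpose. */
static void zip_pass(float a[64]) {
    float tmp[64];
    for (int m = 0; m < 32; m++) {
        tmp[2 * m]     = a[m];       /* element from the first half  */
        tmp[2 * m + 1] = a[32 + m];  /* element from the second half */
    }
    memcpy(a, tmp, sizeof(tmp));
}

void transpose8x8_zip(float a[64]) {
    zip_pass(a);
    zip_pass(a);
    zip_pass(a);
}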

For an 8x8 matrix we need to apply this algorithm three times; by reading the data with vld4, two of those passes have already been done.

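   // Each vld4 de-interleaves 16 consecutive floats (two matrix rows) into four vectors with a stride of 4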
   float32x4x4_t d0 = vld4q_f32(input);
   float32x4x4_t d1 = vld4q_f32(input + 16);
   float32x4x4_t d2 = vld4q_f32(input + 32);
   float32x4x4_t d3 = vld4q_f32(input + 48);
   // One more de-interleaving pass: uzp1/uzp2 across the row pairs completes the transpose
   float32x4x4_t e0 = {
       vuzp1q_f32(d0.val[0], d1.val[0]),
       vuzp1q_f32(d2.val[0], d3.val[0]),
       vuzp1q_f32(d0.val[1], d1.val[1]),
       vuzp1q_f32(d2.val[1], d3.val[1])
   };
   float32x4x4_t e1 = {
       vuzp1q_f32(d0.val[2], d1.val[2]),
       vuzp1q_f32(d2.val[2], d3.val[2]),
       vuzp1q_f32(d0.val[3], d1.val[3]),
       vuzp1q_f32(d2.val[3], d3.val[3])
   };
   float32x4x4_t e2 = {
       vuzp2q_f32(d0.val[0], d1.val[0]),
       vuzp2q_f32(d2.val[0], d3.val[0]),
       vuzp2q_f32(d0.val[1], d1.val[1]),
       vuzp2q_f32(d2.val[1], d3.val[1])
   };
   float32x4x4_t e3 = {
       vuzp2q_f32(d0.val[2], d1.val[2]),
       vuzp2q_f32(d2.val[2], d3.val[2]),
       vuzp2q_f32(d0.val[3], d1.val[3]),
       vuzp2q_f32(d2.val[3], d3.val[3])
   };
   vst1q_f32_x4(output, e0);
   vst1q_f32_x4(output + 16, e1);
   vst1q_f32_x4(output + 32, e2);
   vst1q_f32_x4(output + 48, e3);

One should also be able to perform the transpose by starting with vld1q_f32_x4, then using uzpq, and finishing with vst4q_f32.

Upvotes: 1

First off, you shouldn't expect a huge performance boost to begin with:

  • there is no actual computation
  • you are dealing with 32-bit data, so there isn't much of a bandwidth constraint either

To sum it up, vectorizing buys just a small saving in bandwidth - that's all.

As for the 4x4 transpose, you don't even need a separate function, but just a macro:

#define TRANSPOSE4x4(pSrc,pDst) vst1q_f32_x4(pDst,vld4q_f32(pSrc))

will do the job since NEON does the 4x4 transpose on the fly when you load the data with vld4.
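
For a standalone, contiguous 4x4 tile, usage could look like the sketch below (demo4x4 is just an illustrative wrapper; the 16 floats have to be consecutive in memory for vld4 to see them as a 4x4 block):

#include <arm_neon.h>

#define TRANSPOSE4x4(pSrc,pDst) vst1q_f32_x4(pDst,vld4q_f32(pSrc))

void demo4x4(void) {
    /* row-major 4x4 input: 0..15 */
    float src[16] = { 0,  1,  2,  3,
                      4,  5,  6,  7,
                      8,  9, 10, 11,
                     12, 13, 14, 15 };
    float dst[16];
    TRANSPOSE4x4(src, dst);
    /* dst now holds 0 4 8 12 | 1 5 9 13 | 2 6 10 14 | 3 7 11 15 */
}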

But you should ask yourself at this point whether your approach - transposing all the matrices prior to the actual computation - is the right one, given that a 4x4 transpose costs virtually nothing. This step could end up being a pure waste of computation and bandwidth. Optimization shouldn't be limited to the final step; it should be considered from the design phase on.

8x8 transpose is a different animal though:

void transpose8x8(float *pDst, float *pSrc)
    {
        float32x4_t row0a, row0b, row1a, row1b, row2a, row2b, row3a, row3b, row4a, row4b, row5a, row5b, row6a, row6b, row7a, row7b;
        float32x4_t r0a, r0b, r1a, r1b, r2a, r2b, r3a, r3b, r4a, r4b, r5a, r5b, r6a, r6b, r7a, r7b;

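        // Load the eight source rows; each row is split across two q registers (a = columns 0-3, b = columns 4-7)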
        row0a = vld1q_f32(pSrc);
        pSrc += 4;
        row0b = vld1q_f32(pSrc);
        pSrc += 4;
        row1a = vld1q_f32(pSrc);
        pSrc += 4;
        row1b = vld1q_f32(pSrc);
        pSrc += 4;
        row2a = vld1q_f32(pSrc);
        pSrc += 4;
        row2b = vld1q_f32(pSrc);
        pSrc += 4;
        row3a = vld1q_f32(pSrc);
        pSrc += 4;
        row3b = vld1q_f32(pSrc);
        pSrc += 4;
        row4a = vld1q_f32(pSrc);
        pSrc += 4;
        row4b = vld1q_f32(pSrc);
        pSrc += 4;
        row5a = vld1q_f32(pSrc);
        pSrc += 4;
        row5b = vld1q_f32(pSrc);
        pSrc += 4;
        row6a = vld1q_f32(pSrc);
        pSrc += 4;
        row6b = vld1q_f32(pSrc);
        pSrc += 4;
        row7a = vld1q_f32(pSrc);
        pSrc += 4;
        row7b = vld1q_f32(pSrc);

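        // Pass 1: 32-bit trn1/trn2 transpose the 2x2 sub-blocks within each pair of adjacent rows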
        r0a = vtrn1q_f32(row0a, row1a);
        r0b = vtrn1q_f32(row0b, row1b);
        r1a = vtrn2q_f32(row0a, row1a);
        r1b = vtrn2q_f32(row0b, row1b);
        r2a = vtrn1q_f32(row2a, row3a);
        r2b = vtrn1q_f32(row2b, row3b);
        r3a = vtrn2q_f32(row2a, row3a);
        r3b = vtrn2q_f32(row2b, row3b);
        r4a = vtrn1q_f32(row4a, row5a);
        r4b = vtrn1q_f32(row4b, row5b);
        r5a = vtrn2q_f32(row4a, row5a);
        r5b = vtrn2q_f32(row4b, row5b);
        r6a = vtrn1q_f32(row6a, row7a);
        r6b = vtrn1q_f32(row6b, row7b);
        r7a = vtrn2q_f32(row6a, row7a);
        r7b = vtrn2q_f32(row6b, row7b);

        // Pass 2: 64-bit trn combines the interleaved quarters; the casts are
        // needed because vtrn1q_f64/vtrn2q_f64 operate on float64x2_t
        row0a = vreinterpretq_f32_f64(vtrn1q_f64(vreinterpretq_f64_f32(r0a), vreinterpretq_f64_f32(r2a)));
        row0b = vreinterpretq_f32_f64(vtrn1q_f64(vreinterpretq_f64_f32(r0b), vreinterpretq_f64_f32(r2b)));
        row1a = vreinterpretq_f32_f64(vtrn1q_f64(vreinterpretq_f64_f32(r1a), vreinterpretq_f64_f32(r3a)));
        row1b = vreinterpretq_f32_f64(vtrn1q_f64(vreinterpretq_f64_f32(r1b), vreinterpretq_f64_f32(r3b)));
        row2a = vreinterpretq_f32_f64(vtrn2q_f64(vreinterpretq_f64_f32(r0a), vreinterpretq_f64_f32(r2a)));
        row2b = vreinterpretq_f32_f64(vtrn2q_f64(vreinterpretq_f64_f32(r0b), vreinterpretq_f64_f32(r2b)));
        row3a = vreinterpretq_f32_f64(vtrn2q_f64(vreinterpretq_f64_f32(r1a), vreinterpretq_f64_f32(r3a)));
        row3b = vreinterpretq_f32_f64(vtrn2q_f64(vreinterpretq_f64_f32(r1b), vreinterpretq_f64_f32(r3b)));
        row4a = vreinterpretq_f32_f64(vtrn1q_f64(vreinterpretq_f64_f32(r4a), vreinterpretq_f64_f32(r6a)));
        row4b = vreinterpretq_f32_f64(vtrn1q_f64(vreinterpretq_f64_f32(r4b), vreinterpretq_f64_f32(r6b)));
        row5a = vreinterpretq_f32_f64(vtrn1q_f64(vreinterpretq_f64_f32(r5a), vreinterpretq_f64_f32(r7a)));
        row5b = vreinterpretq_f32_f64(vtrn1q_f64(vreinterpretq_f64_f32(r5b), vreinterpretq_f64_f32(r7b)));
        row6a = vreinterpretq_f32_f64(vtrn2q_f64(vreinterpretq_f64_f32(r4a), vreinterpretq_f64_f32(r6a)));
        row6b = vreinterpretq_f32_f64(vtrn2q_f64(vreinterpretq_f64_f32(r4b), vreinterpretq_f64_f32(r6b)));
        row7a = vreinterpretq_f32_f64(vtrn2q_f64(vreinterpretq_f64_f32(r5a), vreinterpretq_f64_f32(r7a)));
        row7b = vreinterpretq_f32_f64(vtrn2q_f64(vreinterpretq_f64_f32(r5b), vreinterpretq_f64_f32(r7b)));

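        // Store: each transposed row is written as a pair of registers (row0a+row4a, row1a+row5a, ..., row3b+row7b)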
        vst1q_f32(pDst, row0a);
        pDst += 4;
        vst1q_f32(pDst, row4a);
        pDst += 4;
        vst1q_f32(pDst, row1a);
        pDst += 4;
        vst1q_f32(pDst, row5a);
        pDst += 4;
        vst1q_f32(pDst, row2a);
        pDst += 4;
        vst1q_f32(pDst, row6a);
        pDst += 4;
        vst1q_f32(pDst, row3a);
        pDst += 4;
        vst1q_f32(pDst, row7a);
        pDst += 4;
        vst1q_f32(pDst, row0b);
        pDst += 4;
        vst1q_f32(pDst, row4b);
        pDst += 4;
        vst1q_f32(pDst, row1b);
        pDst += 4;
        vst1q_f32(pDst, row5b);
        pDst += 4;
        vst1q_f32(pDst, row2b);
        pDst += 4;
        vst1q_f32(pDst, row6b);
        pDst += 4;
        vst1q_f32(pDst, row3b);
        pDst += 4;
        vst1q_f32(pDst, row7b);

    }

It boils down to: 16 loads + 32 trns + 16 stores vs. 64 loads + 64 stores.

Now we can clearly see it really isn't worth it. The neon routine above might be a little faster, but I doubt it will make a difference in the end.

No, you can't optimize it any further. Nobody can. Just make sure the pointers are 64-byte aligned, test it, and decide for yourself.
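
For the "test it" part, a rough harness could look like the sketch below (aligned_alloc is C11; transpose8x8_neon and transpose8x8_ref are placeholder names for whichever two versions you want to compare):

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

void transpose8x8_neon(float *pDst, float *pSrc);  /* e.g. the intrinsic version above */
void transpose8x8_ref(float *pDst, float *pSrc);   /* e.g. the scalar baseline */

int main(void) {
    /* 64-byte aligned buffers, as suggested above */
    float *src = aligned_alloc(64, 64 * sizeof(float));
    float *ref = aligned_alloc(64, 64 * sizeof(float));
    float *out = aligned_alloc(64, 64 * sizeof(float));
    for (int i = 0; i < 64; i++) src[i] = (float)i;

    /* correctness check against the scalar reference */
    transpose8x8_ref(ref, src);
    transpose8x8_neon(out, src);
    if (memcmp(ref, out, 64 * sizeof(float)) != 0) {
        puts("transpose mismatch");
        return 1;
    }

    /* crude timing; use a proper benchmark harness for real numbers */
    struct timespec t0, t1;
    clock_gettime(CLOCK_MONOTONIC, &t0);
    for (int i = 0; i < 1000000; i++)
        transpose8x8_neon(out, src);
    clock_gettime(CLOCK_MONOTONIC, &t1);
    printf("%.2f ns per transpose\n",
           ((t1.tv_sec - t0.tv_sec) * 1e9 + (t1.tv_nsec - t0.tv_nsec)) / 1e6);

    free(src); free(ref); free(out);
    return 0;
}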

ld1     {v0.4s-v3.4s}, [x1], #64
ld1     {v4.4s-v7.4s}, [x1], #64
ld1     {v16.4s-v19.4s}, [x1], #64
ld1     {v20.4s-v23.4s}, [x1]

trn1    v24.4s, v0.4s, v2.4s    // row0
trn1    v25.4s, v1.4s, v3.4s
trn2    v26.4s, v0.4s, v2.4s    // row1
trn2    v27.4s, v1.4s, v3.4s
trn1    v28.4s, v4.4s, v6.4s    // row2
trn1    v29.4s, v5.4s, v7.4s
trn2    v30.4s, v4.4s, v6.4s    // row3
trn2    v31.4s, v5.4s, v7.4s
trn1    v0.4s, v16.4s, v18.4s   // row4
trn1    v1.4s, v17.4s, v19.4s
trn2    v2.4s, v16.4s, v18.4s   // row5
trn2    v3.4s, v17.4s, v19.4s
trn1    v4.4s, v20.4s, v22.4s   // row6
trn1    v5.4s, v21.4s, v23.4s
trn2    v6.4s, v20.4s, v22.4s   // row7
trn2    v7.4s, v21.4s, v23.4s

trn1    v16.2d, v24.2d, v28.2d  // row0a
trn1    v17.2d, v0.2d, v4.2d    // row0b
trn1    v18.2d, v26.2d, v30.2d  // row1a
trn1    v19.2d, v2.2d, v6.2d    // row1b
trn2    v20.2d, v24.2d, v28.2d  // row2a
trn2    v21.2d, v0.2d, v4.2d    // row2b
trn2    v22.2d, v26.2d, v30.2d  // row3a
trn2    v23.2d, v2.2d, v6.2d    // row3b

st1     {v16.4s-v19.4s}, [x0], #64
st1     {v20.4s-v23.4s}, [x0], #64

trn1    v16.2d, v25.2d, v29.2d  // row4a
trn1    v17.2d, v1.2d, v5.2d    // row4b
trn1    v18.2d, v27.2d, v31.2d  // row5a
trn1    v19.2d, v3.2d, v7.2d    // row5b
trn2    v20.2d, v25.2d, v29.2d  // row6a
trn2    v21.2d, v1.2d, v5.2d    // row6b
trn2    v22.2d, v27.2d, v31.2d  // row7a
trn2    v23.2d, v3.2d, v7.2d    // row7b

st1     {v16.4s-v19.4s}, [x0], #64
st1     {v20.4s-v23.4s}, [x0]

ret

Above is the hand-optimized assembly version - most probably as short as it can get, but not meaningfully faster than the pure C version below, which is the one I'd settle for:

void transpose8x8(float *pDst, float *pSrc)
{
    uint32_t i = 8;
    do {
        pDst[0] = *pSrc++;
        pDst[8] = *pSrc++;
        pDst[16] = *pSrc++;
        pDst[24] = *pSrc++;
        pDst[32] = *pSrc++;
        pDst[40] = *pSrc++;
        pDst[48] = *pSrc++;
        pDst[56] = *pSrc++;
        pDst++;            
    } while (--i);
}

or

void transpose8x8(float *pDst, float *pSrc)
{
    uint32_t i = 8;
    do {
        *pDst++ = pSrc[0];
        *pDst++ = pSrc[8];
        *pDst++ = pSrc[16];
        *pDst++ = pSrc[24];
        *pDst++ = pSrc[32];
        *pDst++ = pSrc[40];
        *pDst++ = pSrc[48];
        *pDst++ = pSrc[56];
        pSrc++;
    } while (--i);
}

PS: It could bring some gain in performance/power consumption if you declared pDst and pSrc as uint32_t *, because the compiler would then generate pure integer machine code, which has a wider variety of addressing modes, and use only w registers instead of s ones. Just typecast float * to uint32_t *.
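
For illustration, the cast could look like the sketch below (transpose8x8_u32 is just an illustrative name; note that, strictly speaking, the float-to-uint32_t aliasing falls outside what the C standard guarantees, so memcpy or -fno-strict-aliasing may be needed to keep it well-defined):

#include <stdint.h>

void transpose8x8_u32(float *pDstF, float *pSrcF) {
    /* same copy pattern as above, but expressed as plain 32-bit integer moves */
    uint32_t *pDst = (uint32_t *)pDstF;
    uint32_t *pSrc = (uint32_t *)pSrcF;
    uint32_t i = 8;
    do {
        *pDst++ = pSrc[0];
        *pDst++ = pSrc[8];
        *pDst++ = pSrc[16];
        *pDst++ = pSrc[24];
        *pDst++ = pSrc[32];
        *pDst++ = pSrc[40];
        *pDst++ = pSrc[48];
        *pDst++ = pSrc[56];
        pSrc++;
    } while (--i);
}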

PS2: Clang already utilizes w registers instead of s ones while GCC is being GCC.... When will GNU-shills finally admit the fact that GCC is an extremely bad choice for ARM?
godbolt

PS3: Below is the non-NEON version in assembly (zero latency), since I was very disappointed (even shocked) by both Clang and GCC above:

    .arch armv8-a
    .global transpose8x8
    .text

.balign 64
.func
transpose8x8:
    mov     w10, #8
    sub     x0, x0, #8
.balign 16
1:
    ldr     w2, [x1, #0]
    ldr     w3, [x1, #32]
    ldr     w4, [x1, #64]
    ldr     w5, [x1, #96]
    ldr     w6, [x1, #128]
    ldr     w7, [x1, #160]
    ldr     w8, [x1, #192]
    ldr     w9, [x1, #224]
    subs    w10, w10, #1
    stp     w2, w3, [x0, #8]
    add     x1, x1, #4
    stp     w4, w5, [x0, #16]
    stp     w6, w7, [x0, #24]
    stp     w8, w9, [x0, #32]!
    b.ne    1b
.balign 16
    ret
.endfunc
.end

It's arguably the best version you will ever get if you still insist on doing a pure 8x8 transpose. It might be a little slower than the NEON assembly version, but it consumes considerably less power.

Upvotes: 5
