sourabh jaiswal

Reputation: 1382

fast bit-matrix (64x64) transpose algorithm using SIMD (ARM)

I am trying to understand if there is a fast way to do a matrix transpose (64x64 bits) using ARM SIMD instructions.

I tried exploring the VTRN instruction of ARM SIMD, but I am not sure how to apply it effectively in this scenario.

The input matrix is represented as uint64 mat[64], and the output is supposed to be a bitwise transpose.

For example, if the input is:

0000....
1111....
0000....
1111....

The expected output:

0101....
0101....
0101....
0101....
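
For clarity, a naive scalar reference might look like this (just a sketch; I am assuming bit j of each word, counted from the least significant bit, is column j):

    #include <stdint.h>

    /* Naive reference: dst[j] collects bit j of every source row. */
    void transpose64_naive(uint64_t dst[64], const uint64_t src[64])
    {
        for (int j = 0; j < 64; j++) {
            uint64_t row = 0;
            for (int i = 0; i < 64; i++)
                row |= ((src[i] >> j) & 1) << i;  /* element (i, j) -> (j, i) */
            dst[j] = row;
        }
    }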

Upvotes: 5

Views: 1890

Answers (4)

swenson

Reputation: 56

I did some investigation and reading of the other answers and implemented a few different methods (in Rust and assembly), and benchmarked them on my Apple M2 Max.

tl;dr Using the NEON instructions gives you a 2× advantage over the non-NEON versions.

Results:

recursive:                 163 ns/iter (+/- 10)
Hacker's Delight loop:     157 ns/iter (+/- 21)
Hacker's Delight unrolled:  60 ns/iter (+/- 4)
Lee's answer:               37 ns/iter (+/- 3)
fuz's answer:               31 ns/iter (+/- 1)
  • The "recursive" version is a basic recursive algorithm, slightly optimized using NEON ld4 instructions and the like, but it involves much more memory movement.

  • The "Hacker's Delight loop" version is a pretty straightforward implementation of Hacker's Delight, Figure 7-3 (at least in the 1st edition), but modified to bit-reverse the 64-bit words to match the other answers. It looks something like:

     let mut j = 32;
     let mut m = 0xffffffffu64;
     while j != 0 {
         let mut k = 0;
         while k < 64 {
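             // delta swap between rows k and k+j: the bits of x[k + j]
             // selected by m trade places with the bits of x[k] selected
             // by (m << j)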
             unsafe {
                 let a = *x.get_unchecked(k + j);
                 let b = *x.get_unchecked(k);
                 let t = (a ^ (b >> j)) & m;
                 *x.get_unchecked_mut(k + j) = a ^ t;
                 *x.get_unchecked_mut(k) = b ^ (t << j);
             }
    
             k = (k + j + 1) & !j;
         }
         j >>= 1;
         m ^= m << j;
     }
    

    Rust/LLVM will auto-vectorize this a little, but not a lot (as of Rust 1.76.0).

  • The "Hacker's Delight unrolled" version is the same, but with the loops manually unrolled. I did not observe any auto-vectorization.

  • The "Lee's answer" version is adapted from his second answer, https://stackoverflow.com/a/71653601/10524 -- however, I could not get the code to produce correct output. That could be my fault in transcribing it, or an issue in the original assembly, but I'm not going to spend more time fixing it at the moment.

  • The "fuz's answer" version is adapted from https://stackoverflow.com/a/71555778/10524 -- this code does work.

There's probably a little more room for improvement in the Hacker's Delight unrolled version by switching to assembly/intrinsics for more efficient memory management, and possibly doing more efficient swaps using rotate instructions (as the book even talks about).

Lee's code outputs the transposed bits to a separate area of memory, which is one of the reasons it is slightly slower. If we don't do the copy-back, it is within the margin of error of fuz's answer. There's probably some room for improvement if we do the transform in-place.

fuz's answer seems to be the winner at the moment.

I've implemented these all in Rust (with inline assembly, where applicable) here: https://github.com/swenson/binary_matrix/tree/simd-transpose/src

Upvotes: 3

    .arch armv8-a
    .global transposeBitwise64x64_noob
    .text

.balign 64
.func
// void transposeBitwise64x64_noob(uint64_t *pDst, uint64_t *pSrc);
pDst    .req    x0
pSrc0   .req    x1
pSrc1   .req    x2
pDst1   .req    x3
stride  .req    x4
count   .req    w5

pDst0   .req    pDst

// no "battle plan" here. not needed for such a self-explanatory cakewalk
transposeBitwise64x64_noob:
    add     pSrc1, pSrc0, #64
    movi    v6.16b, #0xcc
    mov     stride, #128
    movi    v7.16b, #0xaa
    sub     pDst, pDst, #32
    mov     count, #2

.balign 16
1:
    ld4     {v16.16b, v17.16b, v18.16b, v19.16b}, [pSrc0], stride
    ld4     {v20.16b, v21.16b, v22.16b, v23.16b}, [pSrc1], stride
    ld4     {v24.16b, v25.16b, v26.16b, v27.16b}, [pSrc0], stride
    ld4     {v28.16b, v29.16b, v30.16b, v31.16b}, [pSrc1], stride

    stp     q16, q20, [pDst, #32]!
    subs    count, count, #1
    stp     q17, q21, [pDst, #1*64]
    stp     q18, q22, [pDst, #2*64]
    stp     q19, q23, [pDst, #3*64]
    stp     q24, q28, [pDst, #4*64]
    stp     q25, q29, [pDst, #5*64]
    stp     q26, q30, [pDst, #6*64]
    stp     q27, q31, [pDst, #7*64]
    b.ne    1b
    // 8x64 matrix transpose virtually finished. What moron needs zip1/zip2/trn for that?
    nop

    sub     pSrc0, pDst, #32
    add     pSrc1, pDst, #256-32
    mov     count, #4
    sub     pDst0, pDst, #32
    add     pDst1, pSrc0, #256

1:
    // 8x64 matrix transpose finished on-the-fly while reloading. Again, who the hell needs permutation instructions when we have ld2/ld3/ld4?
    ld2     {v24.16b, v25.16b}, [pSrc0], #32
    ld2     {v26.16b, v27.16b}, [pSrc1], #32
    ld2     {v28.16b, v29.16b}, [pSrc0], #32
    ld2     {v30.16b, v31.16b}, [pSrc1], #32
    subs    count, count, #1

    // remark for nosy noobs: the trns below aren't part of the matrix transpose
    trn1    v16.2d, v24.2d, v25.2d  // row0
    trn2    v17.2d, v24.2d, v25.2d  // row1
    trn1    v18.2d, v26.2d, v27.2d  // row2
    trn2    v19.2d, v26.2d, v27.2d  // row3
    trn1    v20.2d, v28.2d, v29.2d  // row4
    trn2    v21.2d, v28.2d, v29.2d  // row5
    trn1    v22.2d, v30.2d, v31.2d  // row6
    trn2    v23.2d, v30.2d, v31.2d  // row7

    mov     v24.16b, v16.16b
    mov     v25.16b, v17.16b
    mov     v26.16b, v18.16b
    mov     v27.16b, v19.16b
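
    // The next three blocks are the masked-shift ladder of an 8x8 bit
    // transpose: swap nibbles between rows 4 apart (sli/sri #4), then bit
    // pairs between rows 2 apart (shift by 2, 0xcc mask in v6), then single
    // bits between adjacent rows (shift by 1, 0xaa mask in v7).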

    sli     v16.16b, v20.16b, #4
    sli     v17.16b, v21.16b, #4
    sli     v18.16b, v22.16b, #4
    sli     v19.16b, v23.16b, #4
    sri     v20.16b, v24.16b, #4
    sri     v21.16b, v25.16b, #4
    sri     v22.16b, v26.16b, #4
    sri     v23.16b, v27.16b, #4

    shl     v24.16b, v18.16b, #2
    shl     v25.16b, v19.16b, #2
    ushr    v26.16b, v16.16b, #2
    ushr    v27.16b, v17.16b, #2
    shl     v28.16b, v22.16b, #2
    shl     v29.16b, v23.16b, #2
    ushr    v30.16b, v20.16b, #2
    ushr    v31.16b, v21.16b, #2

    bit     v16.16b, v24.16b, v6.16b
    bit     v17.16b, v25.16b, v6.16b
    bif     v18.16b, v26.16b, v6.16b
    bif     v19.16b, v27.16b, v6.16b
    bit     v20.16b, v28.16b, v6.16b
    bit     v21.16b, v29.16b, v6.16b
    bif     v22.16b, v30.16b, v6.16b
    bif     v23.16b, v31.16b, v6.16b

    shl     v24.16b, v17.16b, #1
    ushr    v25.16b, v16.16b, #1
    shl     v26.16b, v19.16b, #1
    ushr    v27.16b, v18.16b, #1
    shl     v28.16b, v21.16b, #1
    ushr    v29.16b, v20.16b, #1
    shl     v30.16b, v23.16b, #1
    ushr    v31.16b, v22.16b, #1

    bit     v16.16b, v24.16b, v7.16b
    bif     v17.16b, v25.16b, v7.16b
    bit     v18.16b, v26.16b, v7.16b
    bif     v19.16b, v27.16b, v7.16b
    bit     v20.16b, v28.16b, v7.16b
    bif     v21.16b, v29.16b, v7.16b
    bit     v22.16b, v30.16b, v7.16b
    bif     v23.16b, v31.16b, v7.16b

    st4     {v16.d, v17.d, v18.d, v19.d}[0], [pDst0], #32
    st4     {v16.d, v17.d, v18.d, v19.d}[1], [pDst1], #32
    st4     {v20.d, v21.d, v22.d, v23.d}[0], [pDst0], #32
    st4     {v20.d, v21.d, v22.d, v23.d}[1], [pDst1], #32
    b.ne    1b

// Everyone has a plan until they get punched in the mouth - Mike Tyson

.balign 16
    ret
.endfunc
.end

It's probably the perfect noob's code: perfectly minimized at zero latency.
I was expecting fuz to deliver something similar, but......

Still, it's a noob version after all, and I would give it a C grade (befriedigend, "satisfactory").

Why it doesn't deserve an A grade will be made clear in the next update: bandwidth constraints and power consumption, which are the REAL concerns on modern multicore processors.


fuz's code (see his answer below) gets an F grade (mangelhaft, "deficient").


Upvotes: 0

fuz

Reputation: 93152

The basic recursive scheme for a matrix transposition is to represent the matrix as a block matrix

A B
C D

which you transpose by first transposing each of A, B, C, and D and then swapping B and C. In practice this means applying a sequence of increasingly coarse swizzle steps, first using bitwise operations and later using permutation operations.
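
For concreteness, here is a scalar C sketch of one such swizzle step (my illustration, not fuz's code); the xpstep macro in the assembly further down performs the same exchange on vector registers:

    #include <stdint.h>

    /* One transposition step between two rows that lie `shift` rows apart.
     * `mask` marks the bit positions of `lo` that stay in place: with
     * mask = 0x5555...5555 and shift = 1 this transposes every 2x2 bit block
     * spanning the two rows; with mask = 0x3333...3333 and shift = 2 it
     * merges already-transposed 2x2 blocks into 4x4 blocks, and so on. */
    static inline void xpose_step(uint64_t *lo, uint64_t *hi,
                                  uint64_t mask, unsigned shift)
    {
        uint64_t l = *lo, h = *hi;
        *lo = (l & mask) | ((h << shift) & ~mask);  /* cf. the bif in xpstep */
        *hi = (h & ~mask) | ((l >> shift) & mask);  /* cf. the bit in xpstep */
    }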

This can for example be implemented as follows:

    # transpose a 64x64 bit matrix held in x0
    GLOBL(xpose_asm)
FUNC(xpose_asm)
    # plan of attack: use registers v16--v31 to hold
    # half the array, v0--v7 for scratch.  First transpose
    # the two array halves individually, then swap the
    # second and third quarters.
    mov x4, lr

    mov x2, x0
    bl  NAME(xpose_half)
    mov x3, x0
    bl  NAME(xpose_half)

    # final step: transpose 64x64 bit matrices
    # we have to do this one in two parts so as not to run
    # out of registers
    mov x5, x2
    mov x6, x3
    bl  NAME(xpose_final)
    bl  NAME(xpose_final)

    ret x4
ENDFUNC(xpose_asm)

    # Transpose one 32x64 bit half of the matrix held in x0.
    # On return, x0 has been advanced by 32*8 = 256 bytes.
FUNC(xpose_half)
    # v16 holds rows 0 and 4, v17 holds 1 and 5, and so on
    mov x1, x0
    ld4 {v16.2d, v17.2d, v18.2d, v19.2d}, [x0], #64
    ld4 {v20.2d, v21.2d, v22.2d, v23.2d}, [x0], #64
    ld4 {v24.2d, v25.2d, v26.2d, v27.2d}, [x0], #64
    ld4 {v28.2d, v29.2d, v30.2d, v31.2d}, [x0], #64

    # macro for a transposition step.  Trashes v6 and v7
.macro  xpstep lo, hi, mask, shift
    ushr v6.2d, \lo\().2d, #\shift
    shl v7.2d, \hi\().2d, #\shift
    bif \lo\().16b, v7.16b, \mask\().16b
    bit \hi\().16b, v6.16b, \mask\().16b
.endm

    # 1st step: transpose 2x2 bit matrices
    movi    v0.16b, #0x55
    xpstep  v16, v17, v0, 1
    xpstep  v18, v19, v0, 1
    xpstep  v20, v21, v0, 1
    xpstep  v22, v23, v0, 1
    xpstep  v24, v25, v0, 1
    xpstep  v26, v27, v0, 1
    xpstep  v28, v29, v0, 1
    xpstep  v30, v31, v0, 1

    # 2nd step: transpose 4x4 bit matrices
    movi    v0.16b, #0x33
    xpstep  v16, v18, v0, 2
    xpstep  v17, v19, v0, 2
    xpstep  v20, v22, v0, 2
    xpstep  v21, v23, v0, 2
    xpstep  v24, v26, v0, 2
    xpstep  v25, v27, v0, 2
    xpstep  v28, v30, v0, 2
    xpstep  v29, v31, v0, 2

    # immediate step: zip vectors to change
    # colocation.  As a side effect, every other
    # vector is temporarily relocated to the v0..v7
    # register range
    zip1    v0.2d,  v16.2d, v17.2d
    zip2    v17.2d, v16.2d, v17.2d
    zip1    v1.2d,  v18.2d, v19.2d
    zip2    v19.2d, v18.2d, v19.2d
    zip1    v2.2d,  v20.2d, v21.2d
    zip2    v21.2d, v20.2d, v21.2d
    zip1    v3.2d,  v22.2d, v23.2d
    zip2    v23.2d, v22.2d, v23.2d
    zip1    v4.2d,  v24.2d, v25.2d
    zip2    v25.2d, v24.2d, v25.2d
    zip1    v5.2d,  v26.2d, v27.2d
    zip2    v27.2d, v26.2d, v27.2d
    zip1    v6.2d,  v28.2d, v29.2d
    zip2    v29.2d, v28.2d, v29.2d
    zip1    v7.2d,  v30.2d, v31.2d
    zip2    v31.2d, v30.2d, v31.2d

    # macro for the 3rd transposition step
    # swap low 4 bit of each hi member with
    # high 4 bit of each orig member.  The orig
    # members are copied to lo in the process.
.macro  xpstep3 lo, hi, orig
    mov \lo\().16b, \orig\().16b
    sli \lo\().16b, \hi\().16b, #4
    sri \hi\().16b, \orig\().16b, #4
.endm

    # 3rd step: transpose 8x8 bit matrices
    # special code is needed here since we need to
    # swap row n with row n+4, but these rows are
    # always colocated in the same register
    xpstep3 v16, v17, v0
    xpstep3 v18, v19, v1
    xpstep3 v20, v21, v2
    xpstep3 v22, v23, v3
    xpstep3 v24, v25, v4
    xpstep3 v26, v27, v5
    xpstep3 v28, v29, v6
    xpstep3 v30, v31, v7

    # registers now hold
    # v16: { 0,  1}  v17: { 4,  5}  v18: { 2,  3}  v19: { 6,  7}
    # v20: { 8,  9}  v21: {12, 13}  v22: {10, 11}  v23: {14, 15}
    # v24: {16, 17}  v25: {20, 21}  v26: {18, 19}  v27: {22, 23}
    # v28: {24, 25}  v29: {28, 29}  v30: {26, 27}  v31: {30, 31}

    # 4th step: transpose 16x16 bit matrices
    # this step again moves half the registers to v0--v7
    trn1    v0.16b,  v16.16b, v20.16b
    trn2    v20.16b, v16.16b, v20.16b
    trn1    v1.16b,  v17.16b, v21.16b
    trn2    v21.16b, v17.16b, v21.16b
    trn1    v2.16b,  v18.16b, v22.16b
    trn2    v22.16b, v18.16b, v22.16b
    trn1    v3.16b,  v19.16b, v23.16b
    trn2    v23.16b, v19.16b, v23.16b
    trn1    v4.16b,  v24.16b, v28.16b
    trn2    v28.16b, v24.16b, v28.16b
    trn1    v5.16b,  v25.16b, v29.16b
    trn2    v29.16b, v25.16b, v29.16b
    trn1    v6.16b,  v26.16b, v30.16b
    trn2    v30.16b, v26.16b, v30.16b
    trn1    v7.16b,  v27.16b, v31.16b
    trn2    v31.16b, v27.16b, v31.16b

    # 5th step: transpose 32x32 bit matrices
    # while we are at it, shuffle the order of
    # entries such that they are in order
    trn1    v16.8h, v0.8h, v4.8h
    trn2    v24.8h, v0.8h, v4.8h
    trn1    v18.8h, v1.8h, v5.8h
    trn2    v26.8h, v1.8h, v5.8h
    trn1    v17.8h, v2.8h, v6.8h
    trn2    v25.8h, v2.8h, v6.8h
    trn1    v19.8h, v3.8h, v7.8h
    trn2    v27.8h, v3.8h, v7.8h

    trn1    v0.8h, v20.8h, v28.8h
    trn2    v4.8h, v20.8h, v28.8h
    trn1    v2.8h, v21.8h, v29.8h
    trn2    v6.8h, v21.8h, v29.8h
    trn1    v1.8h, v22.8h, v30.8h
    trn2    v5.8h, v22.8h, v30.8h
    trn1    v3.8h, v23.8h, v31.8h
    trn2    v7.8h, v23.8h, v31.8h

    # now deposit the partially transposed matrix
    st1 {v16.2d, v17.2d, v18.2d, v19.2d}, [x1], #64
    st1 {v0.2d, v1.2d, v2.2d, v3.2d}, [x1], #64
    st1 {v24.2d, v25.2d, v26.2d, v27.2d}, [x1], #64
    st1 {v4.2d, v5.2d, v6.2d, v7.2d}, [x1], #64

    ret
ENDFUNC(xpose_half)

FUNC(xpose_final)
    ld1 {v16.2d, v17.2d, v18.2d, v19.2d}, [x2], #64
    ld1 {v24.2d, v25.2d, v26.2d, v27.2d}, [x3], #64
    ld1 {v20.2d, v21.2d, v22.2d, v23.2d}, [x2], #64
    ld1 {v28.2d, v29.2d, v30.2d, v31.2d}, [x3], #64

    trn1    v0.4s, v16.4s, v24.4s
    trn2    v4.4s, v16.4s, v24.4s
    trn1    v1.4s, v17.4s, v25.4s
    trn2    v5.4s, v17.4s, v25.4s
    trn1    v2.4s, v18.4s, v26.4s
    trn2    v6.4s, v18.4s, v26.4s
    trn1    v3.4s, v19.4s, v27.4s
    trn2    v7.4s, v19.4s, v27.4s

    trn1    v16.4s, v20.4s, v28.4s
    trn2    v24.4s, v20.4s, v28.4s
    trn1    v17.4s, v21.4s, v29.4s
    trn2    v25.4s, v21.4s, v29.4s
    trn1    v18.4s, v22.4s, v30.4s
    trn2    v26.4s, v22.4s, v30.4s
    trn1    v19.4s, v23.4s, v31.4s
    trn2    v27.4s, v23.4s, v31.4s

    st1 {v0.2d, v1.2d, v2.2d, v3.2d}, [x5], #64
    st1 {v4.2d, v5.2d, v6.2d, v7.2d}, [x6], #64
    st1 {v16.2d, v17.2d, v18.2d, v19.2d}, [x5], #64
    st1 {v24.2d, v25.2d, v26.2d, v27.2d}, [x6], #64

    ret
ENDFUNC(xpose_final)

We can see that the performance compares well to Lee's approach, being about three times faster.

# Apple M1
name  time/op
Ref      764ns ± 0%
Lee      102ns ± 0%
Fuz     34.7ns ± 0%

name  speed
Ref    670MB/s ± 0%
Lee   5.01GB/s ± 0%
Fuz   14.7GB/s ± 0%


# Kunpeng 920
name  time/op
Ref     3.73µs ± 0%
Lee      391ns ± 1%
Fuz     96.0ns ± 0%

name  speed
Ref    137MB/s ± 0%
Lee   1.31GB/s ± 1%
Fuz   5.33GB/s ± 0%


# ARM Cortex A72
name  time/op
Ref     8.13µs ± 0%
Lee      892ns ± 0%
Fuz      296ns ± 0%

name  speed
Ref   63.0MB/s ± 0%
Lee    574MB/s ± 0%
Fuz   1.73GB/s ± 0%


# Cavium ThunderX
name  time/op
Ref     19.7µs ± 0%
Lee     1.15µs ± 0%
Fuz      690ns ± 0%

name  speed
Ref   25.9MB/s ± 0%
Lee    444MB/s ± 0%
Fuz    742MB/s ± 0%

Further improvements are likely possible. For example, a suitable permutation mask could be used with the tbl set of instructions to perform multiple transposition steps (especially steps 3 to 5) at once.
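
As a rough illustration of that idea (my sketch, not part of the code above), the byte-lane trn1 used in the 4th step can already be written as a single table lookup; fusing several byte-granularity permutation steps would then only change the index vector:

    #include <arm_neon.h>

    /* Sketch: trn1 vd.16b, a, b expressed as one tbl lookup over the 32-byte
     * table {a, b}: even output bytes come from a, odd ones from b. Composing
     * further byte-level permutations only alters the index vector, which is
     * how several steps could collapse into one vqtbl2q/vqtbl4q per result. */
    static inline uint8x16_t trn1_bytes_via_tbl(uint8x16_t a, uint8x16_t b)
    {
        static const uint8_t idx[16] = {
            0, 16, 2, 18, 4, 20, 6, 22, 8, 24, 10, 26, 12, 28, 14, 30
        };
        uint8x16x2_t table = { { a, b } };
        return vqtbl2q_u8(table, vld1q_u8(idx));
    }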

Note that the algorithm has to load and write out the array just twice: once to transpose every 32x32 sub-array (the two calls to xpose_half) and once more to swap the top-right with the bottom-left 32x32 sub-array. In both cases, maximum-width 64-byte loads and stores are used, reducing the number of memory operations to a minimum.

Upvotes: 6

The data size far exceeds the size of the register bank. You have a choice between:

  • strided load and consecutive store
  • consecutive load and strided store

And consecutive stores are always preferable by a wide margin.

    #include <arm_neon.h>

    void transposeBitwise64x64(uint64_t *pDst, uint64_t *pSrc)
    {
        uint8x8_t drow0, drow1, drow2, drow3, drow4, drow5, drow6, drow7;
        uint8x8_t dtmp0, dtmp1, dtmp2, dtmp3, dtmp4, dtmp5, dtmp6, dtmp7;
        uint8x16_t qrow0, qrow1, qrow2, qrow3, qrow4, qrow5, qrow6, qrow7;
        uint8x16_t qtmp0, qtmp1, qtmp2, qtmp3, qtmp4, qtmp5, qtmp6, qtmp7;
        const intptr_t sstride = 16;
        uint8_t *pSrc1, *pSrc2, *pSrcBase;
        uint32_t count = 8;
    
        drow0 = vmov_n_u8(0);
        drow1 = vmov_n_u8(0);
        drow2 = vmov_n_u8(0);
        drow3 = vmov_n_u8(0);
        drow4 = vmov_n_u8(0);
        drow5 = vmov_n_u8(0);
        drow6 = vmov_n_u8(0);
        drow7 = vmov_n_u8(0);
    
        pSrcBase = (uint8_t *) pSrc;
    
        do {
            pSrc1 = pSrcBase;
            pSrc2 = pSrcBase + 8;
            pSrcBase += 1;
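            /* Gather one byte column of the source (pSrcBase steps one byte
             * per outer iteration): the lane loads below put that byte of
             * row 8*L + K into lane L of drowK, covering all 64 rows. */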
            drow0 = vld1_lane_u8(pSrc1, drow0, 0); pSrc1 += sstride;
            drow1 = vld1_lane_u8(pSrc2, drow1, 0); pSrc2 += sstride;
            drow2 = vld1_lane_u8(pSrc1, drow2, 0); pSrc1 += sstride;
            drow3 = vld1_lane_u8(pSrc2, drow3, 0); pSrc2 += sstride;
            drow4 = vld1_lane_u8(pSrc1, drow4, 0); pSrc1 += sstride;
            drow5 = vld1_lane_u8(pSrc2, drow5, 0); pSrc2 += sstride;
            drow6 = vld1_lane_u8(pSrc1, drow6, 0); pSrc1 += sstride;
            drow7 = vld1_lane_u8(pSrc2, drow7, 0); pSrc2 += sstride;
            drow0 = vld1_lane_u8(pSrc1, drow0, 1); pSrc1 += sstride;
            drow1 = vld1_lane_u8(pSrc2, drow1, 1); pSrc2 += sstride;
            drow2 = vld1_lane_u8(pSrc1, drow2, 1); pSrc1 += sstride;
            drow3 = vld1_lane_u8(pSrc2, drow3, 1); pSrc2 += sstride;
            drow4 = vld1_lane_u8(pSrc1, drow4, 1); pSrc1 += sstride;
            drow5 = vld1_lane_u8(pSrc2, drow5, 1); pSrc2 += sstride;
            drow6 = vld1_lane_u8(pSrc1, drow6, 1); pSrc1 += sstride;
            drow7 = vld1_lane_u8(pSrc2, drow7, 1); pSrc2 += sstride;
            drow0 = vld1_lane_u8(pSrc1, drow0, 2); pSrc1 += sstride;
            drow1 = vld1_lane_u8(pSrc2, drow1, 2); pSrc2 += sstride;
            drow2 = vld1_lane_u8(pSrc1, drow2, 2); pSrc1 += sstride;
            drow3 = vld1_lane_u8(pSrc2, drow3, 2); pSrc2 += sstride;
            drow4 = vld1_lane_u8(pSrc1, drow4, 2); pSrc1 += sstride;
            drow5 = vld1_lane_u8(pSrc2, drow5, 2); pSrc2 += sstride;
            drow6 = vld1_lane_u8(pSrc1, drow6, 2); pSrc1 += sstride;
            drow7 = vld1_lane_u8(pSrc2, drow7, 2); pSrc2 += sstride;
            drow0 = vld1_lane_u8(pSrc1, drow0, 3); pSrc1 += sstride;
            drow1 = vld1_lane_u8(pSrc2, drow1, 3); pSrc2 += sstride;
            drow2 = vld1_lane_u8(pSrc1, drow2, 3); pSrc1 += sstride;
            drow3 = vld1_lane_u8(pSrc2, drow3, 3); pSrc2 += sstride;
            drow4 = vld1_lane_u8(pSrc1, drow4, 3); pSrc1 += sstride;
            drow5 = vld1_lane_u8(pSrc2, drow5, 3); pSrc2 += sstride;
            drow6 = vld1_lane_u8(pSrc1, drow6, 3); pSrc1 += sstride;
            drow7 = vld1_lane_u8(pSrc2, drow7, 3); pSrc2 += sstride;
            drow0 = vld1_lane_u8(pSrc1, drow0, 4); pSrc1 += sstride;
            drow1 = vld1_lane_u8(pSrc2, drow1, 4); pSrc2 += sstride;
            drow2 = vld1_lane_u8(pSrc1, drow2, 4); pSrc1 += sstride;
            drow3 = vld1_lane_u8(pSrc2, drow3, 4); pSrc2 += sstride;
            drow4 = vld1_lane_u8(pSrc1, drow4, 4); pSrc1 += sstride;
            drow5 = vld1_lane_u8(pSrc2, drow5, 4); pSrc2 += sstride;
            drow6 = vld1_lane_u8(pSrc1, drow6, 4); pSrc1 += sstride;
            drow7 = vld1_lane_u8(pSrc2, drow7, 4); pSrc2 += sstride;
            drow0 = vld1_lane_u8(pSrc1, drow0, 5); pSrc1 += sstride;
            drow1 = vld1_lane_u8(pSrc2, drow1, 5); pSrc2 += sstride;
            drow2 = vld1_lane_u8(pSrc1, drow2, 5); pSrc1 += sstride;
            drow3 = vld1_lane_u8(pSrc2, drow3, 5); pSrc2 += sstride;
            drow4 = vld1_lane_u8(pSrc1, drow4, 5); pSrc1 += sstride;
            drow5 = vld1_lane_u8(pSrc2, drow5, 5); pSrc2 += sstride;
            drow6 = vld1_lane_u8(pSrc1, drow6, 5); pSrc1 += sstride;
            drow7 = vld1_lane_u8(pSrc2, drow7, 5); pSrc2 += sstride;
            drow0 = vld1_lane_u8(pSrc1, drow0, 6); pSrc1 += sstride;
            drow1 = vld1_lane_u8(pSrc2, drow1, 6); pSrc2 += sstride;
            drow2 = vld1_lane_u8(pSrc1, drow2, 6); pSrc1 += sstride;
            drow3 = vld1_lane_u8(pSrc2, drow3, 6); pSrc2 += sstride;
            drow4 = vld1_lane_u8(pSrc1, drow4, 6); pSrc1 += sstride;
            drow5 = vld1_lane_u8(pSrc2, drow5, 6); pSrc2 += sstride;
            drow6 = vld1_lane_u8(pSrc1, drow6, 6); pSrc1 += sstride;
            drow7 = vld1_lane_u8(pSrc2, drow7, 6); pSrc2 += sstride;
            drow0 = vld1_lane_u8(pSrc1, drow0, 7); pSrc1 += sstride;
            drow1 = vld1_lane_u8(pSrc2, drow1, 7); pSrc2 += sstride;
            drow2 = vld1_lane_u8(pSrc1, drow2, 7); pSrc1 += sstride;
            drow3 = vld1_lane_u8(pSrc2, drow3, 7); pSrc2 += sstride;
            drow4 = vld1_lane_u8(pSrc1, drow4, 7); pSrc1 += sstride;
            drow5 = vld1_lane_u8(pSrc2, drow5, 7); pSrc2 += sstride;
            drow6 = vld1_lane_u8(pSrc1, drow6, 7);
            drow7 = vld1_lane_u8(pSrc2, drow7, 7);
    
            dtmp0 = vshr_n_u8(drow0, 1);
            dtmp1 = vshr_n_u8(drow1, 1);
            dtmp2 = vshr_n_u8(drow2, 1);
            dtmp3 = vshr_n_u8(drow3, 1);
            dtmp4 = vshr_n_u8(drow4, 1);
            dtmp5 = vshr_n_u8(drow5, 1);
            dtmp6 = vshr_n_u8(drow6, 1);
            dtmp7 = vshr_n_u8(drow7, 1);
    
            qrow0 = vcombine_u8(drow0, dtmp0);
            qrow1 = vcombine_u8(drow1, dtmp1);
            qrow2 = vcombine_u8(drow2, dtmp2);
            qrow3 = vcombine_u8(drow3, dtmp3);
            qrow4 = vcombine_u8(drow4, dtmp4);
            qrow5 = vcombine_u8(drow5, dtmp5);
            qrow6 = vcombine_u8(drow6, dtmp6);
            qrow7 = vcombine_u8(drow7, dtmp7);
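            /* Each qrowK pairs the gathered column bytes (low half) with the
             * same bytes shifted right by 1 (high half), so every vsli cascade
             * below yields two rows of the transposed matrix at once. */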
    
    //////////////////////////////////////
    
            qtmp0 = qrow0;
            qtmp1 = qrow1;
            qtmp2 = qrow2;
            qtmp3 = qrow3;
            qtmp4 = qrow4;
            qtmp5 = qrow5;
            qtmp6 = qrow6;
            qtmp7 = qrow7;
    
            qtmp0 = vsliq_n_u8(qtmp0, qtmp1, 1);
            qtmp2 = vsliq_n_u8(qtmp2, qtmp3, 1);
            qtmp4 = vsliq_n_u8(qtmp4, qtmp5, 1);
            qtmp6 = vsliq_n_u8(qtmp6, qtmp7, 1);
    
            qtmp0 = vsliq_n_u8(qtmp0, qtmp2, 2);
            qtmp4 = vsliq_n_u8(qtmp4, qtmp6, 2);
    
            qtmp0 = vsliq_n_u8(qtmp0, qtmp4, 4);
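            /* Bit k of each byte of qtmp0 is now bit 0 (low half) or bit 1
             * (high half) of the corresponding byte of qrowK: these 16 bytes
             * are two finished rows of the transposed output. */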
    
            vst1q_u8((uint8_t *)pDst, qtmp0); pDst += 2;
    
    //////////////////////////////////////
    
            qtmp0 = vshrq_n_u8(qrow0, 2);
            qtmp1 = vshrq_n_u8(qrow1, 2);
            qtmp2 = vshrq_n_u8(qrow2, 2);
            qtmp3 = vshrq_n_u8(qrow3, 2);
            qtmp4 = vshrq_n_u8(qrow4, 2);
            qtmp5 = vshrq_n_u8(qrow5, 2);
            qtmp6 = vshrq_n_u8(qrow6, 2);
            qtmp7 = vshrq_n_u8(qrow7, 2);
    
            qtmp0 = vsliq_n_u8(qtmp0, qtmp1, 1);
            qtmp2 = vsliq_n_u8(qtmp2, qtmp3, 1);
            qtmp4 = vsliq_n_u8(qtmp4, qtmp5, 1);
            qtmp6 = vsliq_n_u8(qtmp6, qtmp7, 1);
    
            qtmp0 = vsliq_n_u8(qtmp0, qtmp2, 2);
            qtmp4 = vsliq_n_u8(qtmp4, qtmp6, 2);
    
            qtmp0 = vsliq_n_u8(qtmp0, qtmp4, 4);
    
            vst1q_u8((uint8_t *)pDst, qtmp0); pDst += 2;
    
            //////////////////////////////////////
    
            qtmp0 = vshrq_n_u8(qrow0, 4);
            qtmp1 = vshrq_n_u8(qrow1, 4);
            qtmp2 = vshrq_n_u8(qrow2, 4);
            qtmp3 = vshrq_n_u8(qrow3, 4);
            qtmp4 = vshrq_n_u8(qrow4, 4);
            qtmp5 = vshrq_n_u8(qrow5, 4);
            qtmp6 = vshrq_n_u8(qrow6, 4);
            qtmp7 = vshrq_n_u8(qrow7, 4);
    
            qtmp0 = vsliq_n_u8(qtmp0, qtmp1, 1);
            qtmp2 = vsliq_n_u8(qtmp2, qtmp3, 1);
            qtmp4 = vsliq_n_u8(qtmp4, qtmp5, 1);
            qtmp6 = vsliq_n_u8(qtmp6, qtmp7, 1);
    
            qtmp0 = vsliq_n_u8(qtmp0, qtmp2, 2);
            qtmp4 = vsliq_n_u8(qtmp4, qtmp6, 2);
    
            qtmp0 = vsliq_n_u8(qtmp0, qtmp4, 4);
    
            vst1q_u8((uint8_t *)pDst, qtmp0); pDst += 2;
    
            //////////////////////////////////////
    
            qtmp0 = vshrq_n_u8(qrow0, 6);
            qtmp1 = vshrq_n_u8(qrow1, 6);
            qtmp2 = vshrq_n_u8(qrow2, 6);
            qtmp3 = vshrq_n_u8(qrow3, 6);
            qtmp4 = vshrq_n_u8(qrow4, 6);
            qtmp5 = vshrq_n_u8(qrow5, 6);
            qtmp6 = vshrq_n_u8(qrow6, 6);
            qtmp7 = vshrq_n_u8(qrow7, 6);
    
            qtmp0 = vsliq_n_u8(qtmp0, qtmp1, 1);
            qtmp2 = vsliq_n_u8(qtmp2, qtmp3, 1);
            qtmp4 = vsliq_n_u8(qtmp4, qtmp5, 1);
            qtmp6 = vsliq_n_u8(qtmp6, qtmp7, 1);
    
            qtmp0 = vsliq_n_u8(qtmp0, qtmp2, 2);
            qtmp4 = vsliq_n_u8(qtmp4, qtmp6, 2);
    
            qtmp0 = vsliq_n_u8(qtmp0, qtmp4, 4);
    
            vst1q_u8((uint8_t *)pDst, qtmp0); pDst += 2;
    
        } while (--count);
    }
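
A call site matching the question's layout might look like this (sketch; transposed is just an illustrative name for the separate output buffer this routine needs):

    uint64_t mat[64];          /* input rows, as in the question */
    uint64_t transposed[64];   /* output goes to a separate buffer */
    /* ... fill mat ... */
    transposeBitwise64x64(transposed, mat);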

I tried my best to talk the compilers into generating optimized machine code, but they simply won't listen: godbolt. GCC is especially bad here (as always).
I'll add an assembly version by tomorrow.

Upvotes: -2
