Reputation: 15070
I found the following code in C++ for fast transposition of an 8x8 matrix of 32-bit values: https://stackoverflow.com/a/51887176/1915854
/// Transposes, in place, an 8x8 matrix of 32-bit elements that is stored as
/// eight consecutive 256-bit rows starting at `in`.
///
/// The data is only moved, never interpreted, so the float ("_ps") shuffles
/// work for any 32-bit element type.
///
/// @param in  Pointer to the first element of the matrix; rows 0..7 are read
///            from and written back to this same buffer.
///
/// NOTE(review): the buffer is accessed through `__m256 *`, which presumably
/// requires `in` to be 32-byte aligned - confirm against the callers.
/// NOTE(review): `unsigned long` is 32 bits wide only on some ABIs (e.g.
/// 64-bit Windows); the code touches exactly 8 * 32 bytes regardless - confirm
/// the element type matches the target platform.
inline void Transpose8x8Shuff(unsigned long *in)
{
// View the caller's buffer as eight 256-bit rows.
__m256 *inI = reinterpret_cast<__m256 *>(in);
// Stage 1: interleave the 32-bit elements of adjacent row pairs
// (0,1), (2,3), (4,5), (6,7). unpacklo/unpackhi operate independently on each
// 128-bit half, taking the low / high two elements of that half from both rows.
__m256 rI[8];
rI[0] = _mm256_unpacklo_ps(inI[0], inI[1]);
rI[1] = _mm256_unpackhi_ps(inI[0], inI[1]);
rI[2] = _mm256_unpacklo_ps(inI[2], inI[3]);
rI[3] = _mm256_unpackhi_ps(inI[2], inI[3]);
rI[4] = _mm256_unpacklo_ps(inI[4], inI[5]);
rI[5] = _mm256_unpackhi_ps(inI[4], inI[5]);
rI[6] = _mm256_unpacklo_ps(inI[6], inI[7]);
rI[7] = _mm256_unpackhi_ps(inI[6], inI[7]);
// Stage 2: combine the element pairs of two stage-1 vectors so that each
// 128-bit half holds one column of four consecutive rows.
// _MM_SHUFFLE(1,0,1,0) picks elements 0,1 of both operands,
// _MM_SHUFFLE(3,2,3,2) picks elements 2,3 of both operands (per 128-bit half).
__m256 rrF[8];
// rI is already an array of __m256, so this cast does not change the type.
__m256 *rF = reinterpret_cast<__m256 *>(rI);
rrF[0] = _mm256_shuffle_ps(rF[0], rF[2], _MM_SHUFFLE(1,0,1,0));
rrF[1] = _mm256_shuffle_ps(rF[0], rF[2], _MM_SHUFFLE(3,2,3,2));
rrF[2] = _mm256_shuffle_ps(rF[1], rF[3], _MM_SHUFFLE(1,0,1,0));
rrF[3] = _mm256_shuffle_ps(rF[1], rF[3], _MM_SHUFFLE(3,2,3,2));
rrF[4] = _mm256_shuffle_ps(rF[4], rF[6], _MM_SHUFFLE(1,0,1,0));
rrF[5] = _mm256_shuffle_ps(rF[4], rF[6], _MM_SHUFFLE(3,2,3,2));
rrF[6] = _mm256_shuffle_ps(rF[5], rF[7], _MM_SHUFFLE(1,0,1,0));
rrF[7] = _mm256_shuffle_ps(rF[5], rF[7], _MM_SHUFFLE(3,2,3,2));
// Stage 3: write the result back over the input. All input rows were read in
// stage 1, so overwriting the buffer here is safe.
// 0x20 joins the low 128-bit halves of both operands (output rows 0..3),
// 0x31 joins the high 128-bit halves of both operands (output rows 4..7).
rF = reinterpret_cast<__m256 *>(in);
rF[0] = _mm256_permute2f128_ps(rrF[0], rrF[4], 0x20);
rF[1] = _mm256_permute2f128_ps(rrF[1], rrF[5], 0x20);
rF[2] = _mm256_permute2f128_ps(rrF[2], rrF[6], 0x20);
rF[3] = _mm256_permute2f128_ps(rrF[3], rrF[7], 0x20);
rF[4] = _mm256_permute2f128_ps(rrF[0], rrF[4], 0x31);
rF[5] = _mm256_permute2f128_ps(rrF[1], rrF[5], 0x31);
rF[6] = _mm256_permute2f128_ps(rrF[2], rrF[6], 0x31);
rF[7] = _mm256_permute2f128_ps(rrF[3], rrF[7], 0x31);
}
However, converting it to the Java Vector API ( https://download.java.net/java/early_access/panama/docs/api/jdk.incubator.vector/jdk/incubator/vector/IntVector.html ) is not straightforward, because the Java Vector API doesn't map directly to CPU instructions / C++ intrinsics.
Can you share what the equivalents of the following intrinsics/macros in Java are?
_mm256_unpacklo_ps()
_mm256_unpackhi_ps()
_mm256_shuffle_ps()
_MM_SHUFFLE()
_mm256_permute2f128_ps()
I can use the latest JDK 19.
UPDATE: following the suggestion by @Soonts, I've implemented the following. It passes the tests, but it's terribly slow:
/**
 * SIMD helpers built on the incubating Java Vector API.
 *
 * The shuffle constants below reproduce, as two-input {@code rearrange}
 * patterns, the AVX operations used by the C++ 8x8 transpose this class was
 * ported from. In every pattern a non-negative index i selects lane i of the
 * receiver vector, and a negative index i selects lane (i + 8) of the second
 * vector passed to {@code rearrange}.
 */
public class SimdOps {
/** 256-bit species of {@code int} lanes (8 lanes); used by all shuffles below. */
public static final VectorSpecies<Integer> SPECIES_INT = IntVector.SPECIES_256;
/** 256-bit species of {@code long} lanes; not referenced elsewhere in this class. */
public static final VectorSpecies<Long> SPECIES_LONG = LongVector.SPECIES_256;
/** Equivalent of _mm256_unpacklo_ps: interleaves lanes 0,1 and 4,5 of both inputs. */
public static final VectorShuffle<Integer> vsUnpackLo = VectorShuffle.fromValues(SPECIES_INT, 0, -8, 1, -7, 4, -4,
5, -3);
/** Equivalent of _mm256_unpackhi_ps: interleaves lanes 2,3 and 6,7 of both inputs. */
public static final VectorShuffle<Integer> vsUnpackHi = VectorShuffle.fromValues(SPECIES_INT, 2, -6, 3, -5, 6, -2,
7, -1);
/** Equivalent of _mm256_shuffle_ps with _MM_SHUFFLE(1,0,1,0): lanes 0,1 and 4,5 of each input. */
public static final VectorShuffle<Integer> vsShuffle1010 = VectorShuffle.fromValues(SPECIES_INT, 0, 1, -8, -7, 4,
5, -4, -3);
/** Equivalent of _mm256_shuffle_ps with _MM_SHUFFLE(3,2,3,2): lanes 2,3 and 6,7 of each input. */
public static final VectorShuffle<Integer> vsShuffle3232 = VectorShuffle.fromValues(SPECIES_INT, 2, 3, -6, -5, 6, 7,
-2, -1);
/** Equivalent of _mm256_permute2f128_ps with 0x20: lower four lanes of each input. */
public static final VectorShuffle<Integer> vsPermute0x20 = VectorShuffle.fromValues(SPECIES_INT, 0, 1, 2, 3, -8, -7,
-6, -5);
/** Equivalent of _mm256_permute2f128_ps with 0x31: upper four lanes of each input. */
public static final VectorShuffle<Integer> vsPermute0x31 = VectorShuffle.fromValues(SPECIES_INT, 4, 5, 6, 7, -4, -3,
-2, -1);
// Transpose 8x8 matrix of 32-bit integers, stored in 256-bit SIMD vectors
/**
 * Transposes the matrix whose rows are the vectors in {@code inpM}, replacing
 * the array elements with the rows of the transposed matrix.
 *
 * @param inpM the matrix rows; its length is asserted to equal
 *             {@code Constants.INTS_PER_SIMD} (defined outside this class,
 *             presumably 8 - confirm). The array is overwritten in place.
 */
public static final void transpose8x8(IntVector[] inpM) {
assert inpM.length == Constants.INTS_PER_SIMD;
// https://stackoverflow.com/questions/25622745/transpose-an-8x8-float-using-avx-avx2
// https://stackoverflow.com/questions/73977998/simd-transposition-of-8x8-matrix-of-32-bit-values-in-java
// Stage 1: interleave the elements of adjacent row pairs (0,1), (2,3), (4,5), (6,7).
final IntVector rI0 = inpM[0].rearrange(vsUnpackLo, inpM[1]);
final IntVector rI1 = inpM[0].rearrange(vsUnpackHi, inpM[1]);
final IntVector rI2 = inpM[2].rearrange(vsUnpackLo, inpM[3]);
final IntVector rI3 = inpM[2].rearrange(vsUnpackHi, inpM[3]);
final IntVector rI4 = inpM[4].rearrange(vsUnpackLo, inpM[5]);
final IntVector rI5 = inpM[4].rearrange(vsUnpackHi, inpM[5]);
final IntVector rI6 = inpM[6].rearrange(vsUnpackLo, inpM[7]);
final IntVector rI7 = inpM[6].rearrange(vsUnpackHi, inpM[7]);
// Stage 2: combine element pairs so each half of a vector holds one column
// of four consecutive rows.
final IntVector rrF0 = rI0.rearrange(vsShuffle1010, rI2);
final IntVector rrF1 = rI0.rearrange(vsShuffle3232, rI2);
final IntVector rrF2 = rI1.rearrange(vsShuffle1010, rI3);
final IntVector rrF3 = rI1.rearrange(vsShuffle3232, rI3);
final IntVector rrF4 = rI4.rearrange(vsShuffle1010, rI6);
final IntVector rrF5 = rI4.rearrange(vsShuffle3232, rI6);
final IntVector rrF6 = rI5.rearrange(vsShuffle1010, rI7);
final IntVector rrF7 = rI5.rearrange(vsShuffle3232, rI7);
// Stage 3: join the lower halves (output rows 0..3) and the upper halves
// (output rows 4..7). Every input row was already consumed in stage 1, so
// overwriting inpM here is safe.
inpM[0] = rrF0.rearrange(vsPermute0x20, rrF4);
inpM[1] = rrF1.rearrange(vsPermute0x20, rrF5);
inpM[2] = rrF2.rearrange(vsPermute0x20, rrF6);
inpM[3] = rrF3.rearrange(vsPermute0x20, rrF7);
inpM[4] = rrF0.rearrange(vsPermute0x31, rrF4);
inpM[5] = rrF1.rearrange(vsPermute0x31, rrF5);
inpM[6] = rrF2.rearrange(vsPermute0x31, rrF6);
inpM[7] = rrF3.rearrange(vsPermute0x31, rrF7);
}
};
The bottleneck is jdk.incubator.vector.Int256Vector.rearrange(VectorShuffle, Vector). It's at least 10 times slower than the scalar code. Any ideas?
Upvotes: 1
Views: 289
Reputation: 21956
Disclaimer: I never wrote anything similar in Java.
Based on the documentation, rearrange seems to be the only way to go.
The only issue is how to translate the C intrinsics into the integers for the VectorShuffle<Float>.
Here's C++ code to find out:
/// Prints the eight lanes of `v` as integers on one line, in the form
/// "<name>: l0, l1, l2, l3, l4, l5, l6, l7".
///
/// @param v     Vector whose float lanes are converted to 32-bit integers
///              before printing (the callers in this snippet pass whole numbers).
/// @param name  Label printed in front of the lane values.
void printShuffle( __m256 v, const char* name )
{
    constexpr int kLanes = 8;

    // Convert the float lanes to integers and spill them to memory; the
    // unaligned store places no alignment requirement on the local buffer.
    int lanes[ kLanes ];
    const __m256i asInts = _mm256_cvtps_epi32( v );
    _mm256_storeu_si256( reinterpret_cast<__m256i*>( lanes ), asInts );

    // Label first, then the lanes separated by ", " with a newline after the last.
    printf( "%s: ", name );
    for( int lane = 0; lane < kLanes; ++lane )
    {
        const bool isLast = ( lane == kLanes - 1 );
        printf( isLast ? "%i\n" : "%i, ", lanes[ lane ] );
    }
}
// Evaluates `expr` and prints its lanes via printShuffle, using the expression's
// own source text (stringified with #expr) as the label.
#define TEST( expr ) printShuffle( expr, #expr )
// Prints, for each AVX operation used by the 8x8 transpose, which source lane
// ends up in each output lane. The printed numbers can be used directly as the
// index list of a two-input shuffle.
void printJavaRearranges()
{
// `a` carries its own lane numbers 0..7.
const __m256 a = _mm256_setr_ps( 0, 1, 2, 3, 4, 5, 6, 7 );
// `b` carries lane number minus 8, i.e. -8..-1, so a negative value in the
// output identifies an element that came from the second operand.
const __m256 b = _mm256_sub_ps( a, _mm256_set1_ps( 8 ) );
TEST( _mm256_unpacklo_ps( a, b ) );
TEST( _mm256_unpackhi_ps( a, b ) );
TEST( _mm256_shuffle_ps( a, b, _MM_SHUFFLE(1,0,1,0) ) );
TEST( _mm256_shuffle_ps( a, b, _MM_SHUFFLE(3,2,3,2) ) );
TEST( _mm256_permute2f128_ps( a, b, 0x20 ) );
TEST( _mm256_permute2f128_ps( a, b, 0x31 ) );
}
Output:
_mm256_unpacklo_ps( a, b ): 0, -8, 1, -7, 4, -4, 5, -3
_mm256_unpackhi_ps( a, b ): 2, -6, 3, -5, 6, -2, 7, -1
_mm256_shuffle_ps( a, b, _MM_SHUFFLE(1,0,1,0) ): 0, 1, -8, -7, 4, 5, -4, -3
_mm256_shuffle_ps( a, b, _MM_SHUFFLE(3,2,3,2) ): 2, 3, -6, -5, 6, 7, -2, -1
_mm256_permute2f128_ps( a, b, 0x20 ): 0, 1, 2, 3, -8, -7, -6, -5
_mm256_permute2f128_ps( a, b, 0x31 ): 4, 5, 6, 7, -4, -3, -2, -1
The _mm256_permute2f128_ps instruction can selectively zero out lanes; Java's Vector API probably can't do that. Fortunately, the immediate values in your source code don't zero out any pieces.
If you're lucky, the runtime might map these values (when they are known to JIT in advance and never change) into the corresponding AVX instructions.
Upvotes: 2