Reputation: 7006
Given the following input bytes:
var vBytes = new Vector<byte>(new byte[] {72, 101, 55, 08, 108, 111, 55, 87, 111, 114, 108, 55, 100, 55, 55, 20});
And the given mask:
var mask = new Vector<byte>(55);
How can I find the count of byte 55
in the input array?
I have tried xoring the vBytes
with the mask
var xored = Vector.Xor(mask, vBytes);
which gives:
<127, 82, 0, 91, 91, 88, 0, 96, 88, 69, 91, 0, 83, 0, 0, 35>
But I don't know how to get the count from that.
For the sake of simplicity, let's assume that the input byte length is always equal to Vector<byte>.Count
Upvotes: 2
Views: 2383
Reputation: 5080
I know that I'm super late to the party, but so far none of the answers here actually provide a full solution. Here's my best attempt at one, derived from this Gist and the DotNet source code. All credit goes to the DotNet team and community members here (especially @Peter Cordes).
// Usage: count the spaces (ASCII 32) in a byte buffer and in a string.
// The extension methods are declared on Span<T>/ReadOnlySpan<T>, and an array or string receiver
// does not bind to a span extension method before C# 14, so convert explicitly with AsSpan().
var bytes = Encoding.ASCII.GetBytes("The quick brown fox jumps over the lazy dog.");
var byteCount = bytes.AsSpan().OccurrencesOf(32);
var chars = "The quick brown fox jumps over the lazy dog.";
var charCount = chars.AsSpan().OccurrencesOf(' ');
public static class VectorExtensions
// Number of bytes between offset and length, rounded down to a whole number of 128-bit vectors.
private static nuint GetByteVector128SpanLength(nuint offset, int length) {
    int remaining = length - (int)offset;
    int wholeVectorMask = ~(Vector128<byte>.Count - 1);
    return (nuint)(uint)(remaining & wholeVectorMask);
}
// Number of bytes between offset and length, rounded down to a whole number of 256-bit vectors.
private static nuint GetByteVector256SpanLength(nuint offset, int length) {
    int remaining = length - (int)offset;
    int wholeVectorMask = ~(Vector256<byte>.Count - 1);
    return (nuint)(uint)(remaining & wholeVectorMask);
}
// Number of chars between offset and length, rounded down to a whole number of 128-bit vectors.
private static nint GetCharVector128SpanLength(nint offset, nint length) {
    nint remaining = length - offset;
    return remaining & ~(Vector128<ushort>.Count - 1);
}
// Number of chars between offset and length, rounded down to a whole number of 256-bit vectors.
private static nint GetCharVector256SpanLength(nint offset, nint length) {
    nint remaining = length - offset;
    return remaining & ~(Vector256<ushort>.Count - 1);
}
// Reads 16 bytes, with no alignment requirement, starting offset bytes past start.
private static Vector128<byte> LoadVector128(ref byte start, nuint offset) {
    ref byte source = ref Unsafe.AddByteOffset(ref start, offset);
    return Unsafe.ReadUnaligned<Vector128<byte>>(ref source);
}
// Reads 32 bytes, with no alignment requirement, starting offset bytes past start.
private static Vector256<byte> LoadVector256(ref byte start, nuint offset) {
    ref byte source = ref Unsafe.AddByteOffset(ref start, offset);
    return Unsafe.ReadUnaligned<Vector256<byte>>(ref source);
}
// Reads 8 UTF-16 code units, with no alignment requirement, starting offset chars past start.
private static Vector128<ushort> LoadVector128(ref char start, nint offset) {
    ref char element = ref Unsafe.Add(ref start, offset);
    ref byte raw = ref Unsafe.As<char, byte>(ref element);
    return Unsafe.ReadUnaligned<Vector128<ushort>>(ref raw);
}
// Reads 16 UTF-16 code units, with no alignment requirement, starting offset chars past start.
private static Vector256<ushort> LoadVector256(ref char start, nint offset) {
    ref char element = ref Unsafe.Add(ref start, offset);
    ref byte raw = ref Unsafe.As<char, byte>(ref element);
    return Unsafe.ReadUnaligned<Vector256<ushort>>(ref raw);
}
// Counts the bytes equal to `value` among the `length` bytes that start at `searchSpace`.
// Visible shape: when length > 31, a scalar pass first covers the bytes before the next 16-byte boundary;
// the remainder goes through AVX2 (32-byte) or SSE2 (16-byte) compare-and-accumulate loops, and any
// leftover bytes are sent back to the scalar pass via `goto SequentialScan`.
// NOTE(review): this listing looks truncated. There are no closing braces, no `SequentialScan:` label,
// nothing visible inside the scalar `if (value == ...)` tests, and no visible update of `loopIndex`.
// Restore the missing lines from the original before compiling.
private static unsafe int OccurrencesOf(ref byte searchSpace, byte value, int length) {
var lengthToExamine = ((nuint)length);
var offset = ((nuint)0);
// 64-bit tally; narrowed to int on return.
var result = 0L;
if (Sse2.IsSupported || Avx2.IsSupported) {
if (31 < length) {
// Limit the scalar pass to the bytes in front of the next 16-byte boundary.
lengthToExamine = UnalignedCountVector128(ref searchSpace);
// Scalar pass, eight bytes per iteration.
while (7 < lengthToExamine) {
ref byte current = ref Unsafe.AddByteOffset(ref searchSpace, offset);
if (value == current) {
if (value == Unsafe.AddByteOffset(ref current, 1)) {
if (value == Unsafe.AddByteOffset(ref current, 2)) {
if (value == Unsafe.AddByteOffset(ref current, 3)) {
if (value == Unsafe.AddByteOffset(ref current, 4)) {
if (value == Unsafe.AddByteOffset(ref current, 5)) {
if (value == Unsafe.AddByteOffset(ref current, 6)) {
if (value == Unsafe.AddByteOffset(ref current, 7)) {
lengthToExamine -= 8;
offset += 8;
// Scalar pass, four bytes per iteration.
while (3 < lengthToExamine) {
ref byte current = ref Unsafe.AddByteOffset(ref searchSpace, offset);
if (value == current) {
if (value == Unsafe.AddByteOffset(ref current, 1)) {
if (value == Unsafe.AddByteOffset(ref current, 2)) {
if (value == Unsafe.AddByteOffset(ref current, 3)) {
lengthToExamine -= 4;
offset += 4;
// Scalar pass, one byte per iteration.
while (0 < lengthToExamine) {
if (value == Unsafe.AddByteOffset(ref searchSpace, offset)) {
// Bytes remain past the scalar pass: take a vector path.
if (offset < ((nuint)(uint)length)) {
if (Avx2.IsSupported) {
// Address at `offset` is not 32-byte aligned: consume one 16-byte vector with SSE2 first.
if (0 != (((nuint)(uint)Unsafe.AsPointer(ref searchSpace) + offset) & (nuint)(Vector256<byte>.Count - 1))) {
// CompareEqual yields 0xFF per matching byte; 0 - 0xFF wraps to 1, and SumAbsoluteDifferences
// against zero adds those bytes into two 64-bit lanes.
var sum = Sse2.SumAbsoluteDifferences(Sse2.Subtract(Vector128<byte>.Zero, Sse2.CompareEqual(Vector128.Create(value), LoadVector128(ref searchSpace, offset))).AsByte(), Vector128<byte>.Zero).AsInt64();
offset += 16;
result += (sum.GetElement(0) + sum.GetElement(1));
// Remaining byte count rounded down to whole 32-byte vectors.
lengthToExamine = GetByteVector256SpanLength(offset, length);
var searchMask = Vector256.Create(value);
// Unrolled path: four 32-byte vectors (128 bytes) per inner iteration.
if (127 < lengthToExamine) {
var sum = Vector256<long>.Zero;
do {
var accumulator0 = Vector256<byte>.Zero;
var accumulator1 = Vector256<byte>.Zero;
var accumulator2 = Vector256<byte>.Zero;
var accumulator3 = Vector256<byte>.Zero;
var loopIndex = ((nuint)0);
// Cap the inner loop at 255 rounds so that no byte-wide counter can wrap.
var loopLimit = Math.Min(255, (lengthToExamine / 128));
do {
// Subtracting the 0xFF compare result adds 1 to each matching byte counter.
accumulator0 = Avx2.Subtract(accumulator0, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, offset)));
accumulator1 = Avx2.Subtract(accumulator1, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 32))));
accumulator2 = Avx2.Subtract(accumulator2, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 64))));
accumulator3 = Avx2.Subtract(accumulator3, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 96))));
offset += 128;
} while (loopIndex < loopLimit);
lengthToExamine -= (128 * loopLimit);
// Fold the byte counters into the 64-bit running sum.
sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator0.AsByte(), Vector256<byte>.Zero).AsInt64());
sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator1.AsByte(), Vector256<byte>.Zero).AsInt64());
sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator2.AsByte(), Vector256<byte>.Zero).AsInt64());
sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator3.AsByte(), Vector256<byte>.Zero).AsInt64());
} while (127 < lengthToExamine);
// Add the two 128-bit halves, then the two 64-bit lanes.
var sumX = Avx2.ExtractVector128(sum, 0);
var sumY = Avx2.ExtractVector128(sum, 1);
var sumZ = Sse2.Add(sumX, sumY);
result += (sumZ.GetElement(0) + sumZ.GetElement(1));
// One 32-byte vector per iteration for what the unrolled loop left over.
if (31 < lengthToExamine) {
var sum = Vector256<long>.Zero;
do {
sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(Avx2.Subtract(Vector256<byte>.Zero, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, offset))).AsByte(), Vector256<byte>.Zero).AsInt64());
lengthToExamine -= 32;
offset += 32;
} while (31 < lengthToExamine);
var sumX = Avx2.ExtractVector128(sum, 0);
var sumY = Avx2.ExtractVector128(sum, 1);
var sumZ = Sse2.Add(sumX, sumY);
result += (sumZ.GetElement(0) + sumZ.GetElement(1));
// Tail shorter than one vector: hand it to the scalar pass.
if (offset < ((nuint)(uint)length)) {
lengthToExamine = (((nuint)(uint)length) - offset);
goto SequentialScan;
else if (Sse2.IsSupported) {
// Remaining byte count rounded down to whole 16-byte vectors.
lengthToExamine = GetByteVector128SpanLength(offset, length);
var searchMask = Vector128.Create(value);
// Unrolled path: four 16-byte vectors (64 bytes) per inner iteration.
if (63 < lengthToExamine) {
var sum = Vector128<long>.Zero;
do {
var accumulator0 = Vector128<byte>.Zero;
var accumulator1 = Vector128<byte>.Zero;
var accumulator2 = Vector128<byte>.Zero;
var accumulator3 = Vector128<byte>.Zero;
var loopIndex = ((nuint)0);
// Cap the inner loop at 255 rounds so that no byte-wide counter can wrap.
var loopLimit = Math.Min(255, (lengthToExamine / 64));
do {
accumulator0 = Sse2.Subtract(accumulator0, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, offset)));
accumulator1 = Sse2.Subtract(accumulator1, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 16))));
accumulator2 = Sse2.Subtract(accumulator2, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 32))));
accumulator3 = Sse2.Subtract(accumulator3, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 48))));
offset += 64;
} while (loopIndex < loopLimit);
lengthToExamine -= (64 * loopLimit);
// Fold the byte counters into the 64-bit running sum.
sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator0.AsByte(), Vector128<byte>.Zero).AsInt64());
sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator1.AsByte(), Vector128<byte>.Zero).AsInt64());
sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator2.AsByte(), Vector128<byte>.Zero).AsInt64());
sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator3.AsByte(), Vector128<byte>.Zero).AsInt64());
} while (63 < lengthToExamine);
result += (sum.GetElement(0) + sum.GetElement(1));
// One 16-byte vector per iteration for what the unrolled loop left over.
if (15 < lengthToExamine) {
var sum = Vector128<long>.Zero;
do {
sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(Sse2.Subtract(Vector128<byte>.Zero, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, offset))).AsByte(), Vector128<byte>.Zero).AsInt64());
lengthToExamine -= 16;
offset += 16;
} while (15 < lengthToExamine);
result += (sum.GetElement(0) + sum.GetElement(1));
// Tail shorter than one vector: hand it to the scalar pass.
if (offset < ((nuint)(uint)length)) {
lengthToExamine = (((nuint)(uint)length) - offset);
goto SequentialScan;
// Narrow the 64-bit tally to the int return type.
return ((int)result);
// Counts the chars equal to `value` among the `length` chars that start at `searchSpace`.
// Same visible shape as the byte overload, with offsets and lengths measured in chars: an optional scalar
// pass up to the next 16-byte boundary, then AVX2 (16 chars per vector) or SSE2 (8 chars per vector)
// loops, and leftovers sent back to the scalar pass via `goto SequentialScan`.
// NOTE(review): this listing looks truncated. There are no closing braces, no `SequentialScan:` label,
// nothing visible inside the scalar `if (value == ...)` tests, and no visible update of `loopIndex`.
// Restore the missing lines from the original before compiling.
private static unsafe int OccurrencesOf(ref char searchSpace, char value, int length) {
var lengthToExamine = ((nint)length);
var offset = ((nint)0);
// 64-bit tally; narrowed to int on return.
var result = 0L;
// Odd starting address: the empty branch skips the vector setup in the `else if` below.
if (0 != ((int)Unsafe.AsPointer(ref searchSpace) & 1)) { }
else if (Sse2.IsSupported || Avx2.IsSupported) {
if (15 < length) {
// Limit the scalar pass to the chars in front of the next 16-byte boundary.
lengthToExamine = UnalignedCountVector128(ref searchSpace);
// Scalar pass, four chars per iteration.
while (3 < lengthToExamine) {
ref char current = ref Unsafe.Add(ref searchSpace, offset);
if (value == current) {
if (value == Unsafe.Add(ref current, 1)) {
if (value == Unsafe.Add(ref current, 2)) {
if (value == Unsafe.Add(ref current, 3)) {
lengthToExamine -= 4;
offset += 4;
// Scalar pass, one char per iteration.
while (0 < lengthToExamine) {
if (value == Unsafe.Add(ref searchSpace, offset)) {
// Chars remain past the scalar pass: take a vector path.
if (offset < length) {
if (Avx2.IsSupported) {
// Address at `offset` is not 32-byte aligned: consume one 8-char vector with SSE2 first.
if (0 != (((nint)Unsafe.AsPointer(ref Unsafe.Add(ref searchSpace, offset))) & (Vector256<byte>.Count - 1))) {
// CompareEqual yields 0xFFFF per matching char; 0 - 0xFFFF wraps to 1, and the byte-wise
// SumAbsoluteDifferences against zero adds those 1s into two 64-bit lanes.
var sum = Sse2.SumAbsoluteDifferences(Sse2.Subtract(Vector128<ushort>.Zero, Sse2.CompareEqual(Vector128.Create(value), LoadVector128(ref searchSpace, offset))).AsByte(), Vector128<byte>.Zero).AsInt64();
offset += 8;
result += (sum.GetElement(0) + sum.GetElement(1));
// Remaining char count rounded down to whole 16-char vectors.
lengthToExamine = GetCharVector256SpanLength(offset, length);
var searchMask = Vector256.Create(value);
// Unrolled path: four 16-char vectors (64 chars) per inner iteration.
if (63 < lengthToExamine) {
var sum = Vector256<long>.Zero;
do {
var accumulator0 = Vector256<ushort>.Zero;
var accumulator1 = Vector256<ushort>.Zero;
var accumulator2 = Vector256<ushort>.Zero;
var accumulator3 = Vector256<ushort>.Zero;
var loopIndex = 0;
// Keeping each 16-bit counter at or below 255 leaves its high byte zero, so the byte-wise
// sum below still equals the match count.
var loopLimit = Math.Min(255, (lengthToExamine / 64));
do {
accumulator0 = Avx2.Subtract(accumulator0, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, offset)));
accumulator1 = Avx2.Subtract(accumulator1, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 16))));
accumulator2 = Avx2.Subtract(accumulator2, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 32))));
accumulator3 = Avx2.Subtract(accumulator3, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 48))));
offset += 64;
} while (loopIndex < loopLimit);
lengthToExamine -= (64 * loopLimit);
// Fold the counters into the 64-bit running sum.
sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator0.AsByte(), Vector256<byte>.Zero).AsInt64());
sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator1.AsByte(), Vector256<byte>.Zero).AsInt64());
sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator2.AsByte(), Vector256<byte>.Zero).AsInt64());
sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator3.AsByte(), Vector256<byte>.Zero).AsInt64());
} while (63 < lengthToExamine);
// Add the two 128-bit halves, then the two 64-bit lanes.
var sumX = Avx2.ExtractVector128(sum, 0);
var sumY = Avx2.ExtractVector128(sum, 1);
var sumZ = Sse2.Add(sumX, sumY);
result += (sumZ.GetElement(0) + sumZ.GetElement(1));
// One 16-char vector per iteration for what the unrolled loop left over.
if (15 < lengthToExamine) {
var sum = Vector256<long>.Zero;
do {
sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(Avx2.Subtract(Vector256<ushort>.Zero, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, offset))).AsByte(), Vector256<byte>.Zero).AsInt64());
lengthToExamine -= 16;
offset += 16;
} while (15 < lengthToExamine);
var sumX = Avx2.ExtractVector128(sum, 0);
var sumY = Avx2.ExtractVector128(sum, 1);
var sumZ = Sse2.Add(sumX, sumY);
result += (sumZ.GetElement(0) + sumZ.GetElement(1));
// Tail shorter than one vector: hand it to the scalar pass.
if (offset < length) {
lengthToExamine = (length - offset);
goto SequentialScan;
else if (Sse2.IsSupported) {
// Remaining char count rounded down to whole 8-char vectors.
lengthToExamine = GetCharVector128SpanLength(offset, length);
var searchMask = Vector128.Create(value);
// Unrolled path: four 8-char vectors (32 chars) per inner iteration.
if (31 < lengthToExamine) {
var sum = Vector128<long>.Zero;
do {
var accumulator0 = Vector128<ushort>.Zero;
var accumulator1 = Vector128<ushort>.Zero;
var accumulator2 = Vector128<ushort>.Zero;
var accumulator3 = Vector128<ushort>.Zero;
var loopIndex = 0;
// Keeping each 16-bit counter at or below 255 leaves its high byte zero, so the byte-wise
// sum below still equals the match count.
var loopLimit = Math.Min(255, (lengthToExamine / 32));
do {
accumulator0 = Sse2.Subtract(accumulator0, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, offset)));
accumulator1 = Sse2.Subtract(accumulator1, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 8))));
accumulator2 = Sse2.Subtract(accumulator2, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 16))));
accumulator3 = Sse2.Subtract(accumulator3, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 24))));
offset += 32;
} while (loopIndex < loopLimit);
lengthToExamine -= (32 * loopLimit);
// Fold the counters into the 64-bit running sum.
sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator0.AsByte(), Vector128<byte>.Zero).AsInt64());
sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator1.AsByte(), Vector128<byte>.Zero).AsInt64());
sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator2.AsByte(), Vector128<byte>.Zero).AsInt64());
sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator3.AsByte(), Vector128<byte>.Zero).AsInt64());
} while (31 < lengthToExamine);
result += (sum.GetElement(0) + sum.GetElement(1));
// One 8-char vector per iteration for what the unrolled loop left over.
if (7 < lengthToExamine) {
var sum = Vector128<long>.Zero;
do {
sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(Sse2.Subtract(Vector128<ushort>.Zero, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, offset))).AsByte(), Vector128<byte>.Zero).AsInt64());
lengthToExamine -= 8;
offset += 8;
} while (7 < lengthToExamine);
result += (sum.GetElement(0) + sum.GetElement(1));
// Tail shorter than one vector: hand it to the scalar pass.
if (offset < length) {
lengthToExamine = (length - offset);
goto SequentialScan;
// Narrow the 64-bit tally to the int return type.
return ((int)result);
// Number of bytes from searchSpace up to the next 16-byte boundary (0 when already aligned).
private static unsafe nuint UnalignedCountVector128(ref byte searchSpace) {
    int alignmentMask = Vector128<byte>.Count - 1;
    nint misalignment = (nint)Unsafe.AsPointer(ref searchSpace) & alignmentMask;
    return (nuint)(uint)((Vector128<byte>.Count - misalignment) & alignmentMask);
}
// Number of chars from searchSpace up to the next 16-byte boundary, computed from the negated
// address divided by the two bytes each char occupies and masked to the 0..7 range.
private static unsafe nint UnalignedCountVector128(ref char searchSpace) {
    const int ElementsPerByte = (sizeof(ushort) / sizeof(byte));
    int negatedAddress = -(int)Unsafe.AsPointer(ref searchSpace);
    int negatedInChars = negatedAddress / ElementsPerByte;
    nint widened = (nint)(uint)negatedInChars;
    return widened & (Vector128<ushort>.Count - 1);
}
/// <summary>Counts how many elements of <paramref name="span"/> equal <paramref name="value"/>.</summary>
/// <param name="span">The bytes to search.</param>
/// <param name="value">The byte to count.</param>
/// <returns>The number of matching bytes.</returns>
public static int OccurrencesOf(this ReadOnlySpan<byte> span, byte value) =>
    // Restored call target: the named arguments match the private (ref byte, byte, int) overload.
    OccurrencesOf(
        length: span.Length,
        searchSpace: ref MemoryMarshal.GetReference(span),
        value: value
    );
/// <summary>Counts how many elements of <paramref name="span"/> equal <paramref name="value"/>.</summary>
/// <param name="span">The bytes to search.</param>
/// <param name="value">The byte to count.</param>
/// <returns>The number of matching bytes.</returns>
public static int OccurrencesOf(this Span<byte> span, byte value) =>
    // The expression body was missing; forward to the read-only overload.
    ((ReadOnlySpan<byte>)span).OccurrencesOf(value);
/// <summary>Counts how many elements of <paramref name="span"/> equal <paramref name="value"/>.</summary>
/// <param name="span">The chars to search.</param>
/// <param name="value">The char to count.</param>
/// <returns>The number of matching chars.</returns>
public static int OccurrencesOf(this ReadOnlySpan<char> span, char value) =>
    // Restored call target: the named arguments match the private (ref char, char, int) overload.
    OccurrencesOf(
        length: span.Length,
        searchSpace: ref MemoryMarshal.GetReference(span),
        value: value
    );
/// <summary>Counts how many elements of <paramref name="span"/> equal <paramref name="value"/>.</summary>
/// <param name="span">The chars to search.</param>
/// <param name="value">The char to count.</param>
/// <returns>The number of matching chars.</returns>
public static int OccurrencesOf(this Span<char> span, char value) =>
    // The expression body was missing; forward to the read-only overload.
    ((ReadOnlySpan<char>)span).OccurrencesOf(value);
Upvotes: 4
Reputation: 366094
(AVX2 C intrinsics implementation of the below idea, in case a concrete example helps: How to count character occurrences using SIMD)
In asm, you want pcmpeqb
to produce a vector of 0 or 0xFF. Treated as signed integers, that's 0/-1.
Then use the compare-result as integers values with psubb
to add 0 / 1 to the counter for that element. (Subtract -1 = add +1)
That can overflow after 256 iterations, so sometime before that, use psadbw
against _mm_setzero_si128()
to horizontally sum those unsigned bytes (without overflow) into 64-bit integers (one 64-bit integer per group of 8 bytes). Then paddq
to accumulate 64-bit totals.
Accumulating before you overflow can be done with a nested loop, or just at the end of a regular unrolled loop. psadbw
is fast (because it's a key building block for video encoding motion-search), so it's not bad to just accumulate every 4 compares, or even every 1 and skip the psubb.
See Agner Fog's optimization guides for more details on x86. According to his instruction tables, psadbw xmm
/ vpsadbw ymm
runs at 1 vector per clock cycle on Skylake, with 3 cycle latency. (Only 1 uop of front-end bandwidth.) All the instructions mentioned above are also single-uop, and run on more than one port (so don't necessarily conflict with each other for throughput). Their 128-bit versions only require SSE2.
If you really only have one vector at a time to count, and aren't looping over memory, then probably pcmpeqb
/ psadbw
/ pshufd
(copy high half to low) / paddd
/ movd eax, xmm0
gives you 255 * number of matches in an integer register. One extra vector instruction (like subtract from zero, or AND with 1, or pabsb
(absolute value)) would remove the x255 scale factor.
IDK how to write that in C# SIMD, but you definitely do not want a dot-product! Unpack and convert to FP would be about 4x slower than the above, just from the fact that a fixed-width vector holds 4x more bytes than floats, and dpps
is not fast: 4 uops, and one per 1.5 cycle throughput on Skylake. If you do have to horizontal-sum something other than unsigned bytes, see Fastest way to do horizontal SSE vector sum (or other reduction) (my answer also includes integer).
Or if Vector.Dot
uses pmaddubsw
/ pmaddwd
for integer vectors, then that might not be as bad, but doing a multi-step horizontal sum for each vector of compare results is just bad compared to psadbw
, or especially to byte accumulators that you only horizontal sum occasionally.
Or if C# optimizes out any actual multiplying with a constant vector of 1
. Anyway, the first part of this answer is the code you want the CPU to be running. Make that happen however you like using whatever source code gets it to happen.
Upvotes: 4
Reputation: 321
Here is a fast SSE2 implementation in C:
/*
 * Count the bytes equal to (unsigned char)c in the n bytes starting at s.
 *
 * The bulk is processed with SSE2 in chunks of 252*16 bytes: four byte-wide
 * accumulator vectors each gain -1 per matching byte (the compare result is
 * 0xFF), are negated, and are then summed horizontally into two 64-bit lanes
 * with _mm_sad_epu8. Each accumulator sees 63 vectors per chunk, so no byte
 * lane can wrap. Whatever is left after the last whole chunk is counted with
 * a scalar loop.
 *
 * Returns the number of matching bytes.
 */
size_t memcount_sse2(const void *s, int c, size_t n) {
    const __m128i zero = _mm_setzero_si128();
    const __m128i cv = _mm_set1_epi8((char)c);
    __m128i sum = zero, acr0, acr1, acr2, acr3;
    const unsigned char *p = (const unsigned char *)s;
    const unsigned char *const end = p + n;
    const unsigned char *const simd_end = p + (n - (n % (252*16)));
    const unsigned char *pe;
    unsigned long long lanes[2];
    size_t count;

    while (p != simd_end) {
        acr0 = acr1 = acr2 = acr3 = zero;
        for (pe = p + 252*16; p != pe; p += 64) {
            acr0 = _mm_add_epi8(acr0, _mm_cmpeq_epi8(cv, _mm_loadu_si128((const __m128i *)p)));
            acr1 = _mm_add_epi8(acr1, _mm_cmpeq_epi8(cv, _mm_loadu_si128((const __m128i *)(p+16))));
            acr2 = _mm_add_epi8(acr2, _mm_cmpeq_epi8(cv, _mm_loadu_si128((const __m128i *)(p+32))));
            acr3 = _mm_add_epi8(acr3, _mm_cmpeq_epi8(cv, _mm_loadu_si128((const __m128i *)(p+48))));
        }
        /* Negate the -1 based counters, then add their bytes into the 64-bit totals. */
        sum = _mm_add_epi64(sum, _mm_sad_epu8(_mm_sub_epi8(zero, acr0), zero));
        sum = _mm_add_epi64(sum, _mm_sad_epu8(_mm_sub_epi8(zero, acr1), zero));
        sum = _mm_add_epi64(sum, _mm_sad_epu8(_mm_sub_epi8(zero, acr2), zero));
        sum = _mm_add_epi64(sum, _mm_sad_epu8(_mm_sub_epi8(zero, acr3), zero));
    }

    /* Plain SSE2 horizontal add: store both lanes instead of using _mm_extract_epi64,
     * which needs SSE4.1 and a 64-bit target. */
    _mm_storeu_si128((__m128i *)lanes, sum);
    count = (size_t)(lanes[0] + lanes[1]);

    /* Scalar cleanup. Compare as unsigned char so that values >= 0x80 match
     * even when plain char is signed. */
    while (p != end) count += (*p++ == (unsigned char)c);
    return count;
}
See also: for an AVX2 implementation.
Upvotes: 2