Reputation: 9347
I want to compare two vectors of 16 bytes and get every matching index. A small example to illustrate what I want:
fn get_matching_idx(arr1: &[u8], arr2: &[u8]) {
    let vec1 = u8x16::load_aligned(arr1);
    let vec2 = u8x16::load_aligned(arr2);
    let matches = vec1.eq(vec2);
    for i in 0..16 {
        if matches.extract_unchecked(i) {
            // Do something with the index
        }
    }
}
Ideally, I'd just want to "Do something" for the set indices, rather than checking every single one (there will be a low number of matches).
Is there a way to get the matching indices using intrinsics, rather than iterating through the whole vector? With gcc, for example, I could use _mm_movemask_epi8 to bit-pack the vector and then apply __builtin_clz repeatedly to get the index of the first set bit (which is more performant for the sparse masks I would have). Alternatively, I could use a lookup table that does the right thing for each nibble in my bit-packed integer (e.g. the first answer here).
Is there an equivalent of these instructions in Rust?
I'm compiling for an Intel x86-64 processor, and cross-platform support is not a requirement.
NOTE: I'd prefer a solution in native (safe) Rust, but this is not a hard requirement. I am fine with writing unsafe Rust, or even using some sort of FFI to link to the aforementioned methods.
Upvotes: 5
Views: 1011
Reputation: 9347
std::arch contains an exhaustive set of intrinsic operations. This can be done using core::arch and std::simd as follows:
use std::arch::x86_64::{self, __m128i};
use std::simd::{u8x16, FromBits};

// Returns a mask in which bit i is set when arr1[i] == arr2[i].
unsafe fn get_matching_idx(arr1: &[u8], arr2: &[u8]) -> u32 {
    let vec1 = __m128i::from_bits(u8x16::load_aligned_unchecked(arr1));
    let vec2 = __m128i::from_bits(u8x16::load_aligned_unchecked(arr2));
    x86_64::_mm_movemask_epi8(x86_64::_mm_cmpeq_epi8(vec1, vec2)) as u32
}

fn main() {
    // let arr1 = ...
    // let arr2 = ...
    // Bind the mask outside the unsafe block so the loop below can see it.
    let mut mask = unsafe { get_matching_idx(arr1, arr2) };
    let mut delta_i = 0;
    // Bit i of the mask corresponds to byte i, so counting trailing zeros
    // finds the lowest remaining match.
    while mask > 0 {
        let tz = mask.trailing_zeros();
        let i = tz + delta_i;
        // Do something...
        mask >>= tz + 1;
        delta_i += tz + 1;
    }
}
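For reference, the same idea can be sketched using only std::arch, so it does not depend on std::simd. This is a minimal sketch, not a definitive implementation: the names match_mask and matching_indices are hypothetical, the inputs are assumed to be exactly 16 bytes, and an unaligned load is used so no alignment is assumed. The scalar fallback is only there so the snippet also builds on non-x86-64 targets.

```rust
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::{__m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8};

/// Returns a 16-bit mask in which bit i is set when a[i] == b[i].
#[cfg(target_arch = "x86_64")]
fn match_mask(a: &[u8; 16], b: &[u8; 16]) -> u32 {
    // SAFETY: SSE2 is part of the x86-64 baseline, both pointers refer to
    // 16 readable bytes, and _mm_loadu_si128 has no alignment requirement.
    unsafe {
        let va = _mm_loadu_si128(a.as_ptr() as *const __m128i);
        let vb = _mm_loadu_si128(b.as_ptr() as *const __m128i);
        _mm_movemask_epi8(_mm_cmpeq_epi8(va, vb)) as u32
    }
}

/// Scalar fallback (assumption: only needed when not targeting x86-64).
#[cfg(not(target_arch = "x86_64"))]
fn match_mask(a: &[u8; 16], b: &[u8; 16]) -> u32 {
    (0..16)
        .filter(|&i| a[i] == b[i])
        .fold(0u32, |m, i| m | (1u32 << i))
}

/// Collects the indices of the set bits in `mask`, lowest index first.
fn matching_indices(mut mask: u32) -> Vec<u32> {
    let mut out = Vec::new();
    while mask != 0 {
        // Index of the lowest set bit.
        out.push(mask.trailing_zeros());
        // Clear the lowest set bit; no running offset is needed.
        mask &= mask - 1;
    }
    out
}
```

Calling matching_indices(match_mask(&a, &b)) visits only the matching positions, so the loop runs once per match rather than 16 times. Clearing the lowest set bit with mask &= mask - 1 avoids the shift-and-offset bookkeeping.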
Upvotes: 1