user1413793

Rust get index of true bytes in SIMD vector

I want to compare two vectors of 16 bytes and get every matching index. A small example to illustrate what I want:

fn get_matching_idx(arr1: &[u8], arr2: &[u8]) {
    let vec1 = u8x16::load_aligned(arr1);    
    let vec2 = u8x16::load_aligned(arr2);
    let matches = vec1.eq(vec2);
    for i in 0..16 {
        if matches.extract_unchecked(i) {
            // Do something with the index
        }
    }
}

Ideally, I would only "Do something" for the set indices rather than checking every lane (there will be few matches).

Is there a way to get the matching indices using intrinsics, rather than iterating through the whole vector? With gcc, for example, I could use _mm_movemask_epi8 to pack the comparison result into an integer (bit i set when byte i matches) and then repeatedly apply __builtin_ctz to get the index of the lowest set bit. That is faster than a full scan when few bits are set, as in my case. Alternatively, I could use a lookup table that handles each nibble of the bit-packed integer (e.g. the first answer here).
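The nibble lookup-table alternative could look roughly like this in Rust. This is a sketch under assumptions, not the linked answer's code: `NIBBLE_BITS` and `for_each_set_bit_lut` are hypothetical names, and `mask` is assumed to be the 16-bit result of `_mm_movemask_epi8`.

```rust
/// Hypothetical table: for each 4-bit value, the positions of its set bits.
const NIBBLE_BITS: [&[u32]; 16] = [
    &[],
    &[0],
    &[1],
    &[0, 1],
    &[2],
    &[0, 2],
    &[1, 2],
    &[0, 1, 2],
    &[3],
    &[0, 3],
    &[1, 3],
    &[0, 1, 3],
    &[2, 3],
    &[0, 2, 3],
    &[1, 2, 3],
    &[0, 1, 2, 3],
];

/// Calls `f` with the index of every set bit in the low 16 bits of `mask`,
/// lowest index first, using one table lookup per nibble.
fn for_each_set_bit_lut(mask: u32, mut f: impl FnMut(u32)) {
    for nibble in 0..4 {
        let bits = (mask >> (4 * nibble)) & 0xF;
        for &b in NIBBLE_BITS[bits as usize] {
            f(4 * nibble + b);
        }
    }
}

fn main() {
    let mut found = Vec::new();
    // Bits 0, 2 and 15 are set.
    for_each_set_bit_lut(0b1000_0000_0000_0101, |i| found.push(i));
    println!("{:?}", found); // prints "[0, 2, 15]"
}
```

Each nibble costs one table lookup regardless of how many of its bits are set, whereas a trailing-zero loop costs one iteration per match, so with sparse matches the latter usually does less work.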

Is there an equivalent of these intrinsics in Rust?

I'm compiling for an Intel x86-64 processor and cross platform support is not a requirement.

NOTE: I'd prefer a solution in native (safe) Rust, but this is not a hard requirement. I am fine writing unsafe Rust, or even using some sort of FFI to link to the gcc builtins mentioned above.


Answers (1)

user1413793

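std::arch (a re-export of core::arch) exposes the x86_64 vendor intrinsics, so the gcc approach carries over directly: _mm_cmpeq_epi8 and _mm_movemask_epi8 build the bit mask, and the trailing-zero count is available either as the _mm_tzcnt_32 intrinsic (which needs the bmi1 target feature and an unsafe call) or, safely, as u32::trailing_zeros. For example, on stable Rust:

use std::arch::x86_64::{__m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8};

/// Returns a mask in which bit i is set iff arr1[i] == arr2[i].
fn get_matching_idx(arr1: &[u8; 16], arr2: &[u8; 16]) -> u32 {
    // SAFETY: SSE2 is part of the x86_64 baseline, and _mm_loadu_si128
    // does an unaligned load, so any 16-byte array is a valid source.
    unsafe {
        let vec1 = _mm_loadu_si128(arr1.as_ptr() as *const __m128i);
        let vec2 = _mm_loadu_si128(arr2.as_ptr() as *const __m128i);
        // cmpeq sets each equal byte to 0xFF; movemask packs the top bit
        // of each byte into the low 16 bits of the result.
        _mm_movemask_epi8(_mm_cmpeq_epi8(vec1, vec2)) as u32
    }
}

fn main() {
    let arr1 = [1u8; 16];
    let mut arr2 = [0u8; 16];
    arr2[3] = 1;
    arr2[9] = 1;

    let mut mask = get_matching_idx(&arr1, &arr2);
    // Bit i of the mask corresponds to byte i, so the trailing-zero
    // count is the lowest remaining matching index.
    while mask != 0 {
        let i = mask.trailing_zeros(); // typically a single tzcnt/bsf
        // Do something with index i...
        println!("match at index {}", i);
        mask &= mask - 1; // clear the lowest set bit
    }
}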
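If you would rather write a plain `for` loop over the matching indices, the clear-lowest-bit loop can be wrapped in an iterator. This is a sketch, not part of the original answer: `SetBits` and `match_mask` are hypothetical names, and the scalar fallback for non-x86_64 targets is an addition so the snippet also builds elsewhere.

```rust
/// Iterator over the indices of the set bits in a mask, lowest index first.
struct SetBits(u32);

impl Iterator for SetBits {
    type Item = u32;

    fn next(&mut self) -> Option<u32> {
        if self.0 == 0 {
            return None;
        }
        let i = self.0.trailing_zeros(); // index of the lowest set bit
        self.0 &= self.0 - 1; // clear that bit
        Some(i)
    }
}

/// Bit i of the result is set iff a[i] == b[i].
#[cfg(target_arch = "x86_64")]
fn match_mask(a: &[u8; 16], b: &[u8; 16]) -> u32 {
    use std::arch::x86_64::{__m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8};
    // SAFETY: SSE2 is always available on x86_64, and the loads are unaligned.
    unsafe {
        let va = _mm_loadu_si128(a.as_ptr() as *const __m128i);
        let vb = _mm_loadu_si128(b.as_ptr() as *const __m128i);
        _mm_movemask_epi8(_mm_cmpeq_epi8(va, vb)) as u32
    }
}

/// Scalar fallback (an assumption, not needed on x86_64) with the same result.
#[cfg(not(target_arch = "x86_64"))]
fn match_mask(a: &[u8; 16], b: &[u8; 16]) -> u32 {
    (0..16)
        .filter(|&i| a[i] == b[i])
        .fold(0, |m, i| m | (1u32 << i))
}

fn main() {
    let a = [1u8; 16];
    let mut b = [0u8; 16];
    b[3] = 1;
    b[9] = 1;
    for i in SetBits(match_mask(&a, &b)) {
        println!("match at index {}", i); // indices 3 and 9
    }
}
```

The loop body runs once per match rather than once per lane, which is the property the question asks for when matches are sparse.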
