Poperton

Reputation: 2127

How to auto-vectorize (SIMD) a modular multiplication in Rust

I'm trying to optimize code that does modular multiplication so that it uses SIMD auto-vectorization. That is, I don't want to use any libraries; the compiler should do the job. Here's the smallest verifiable example I could come up with:

#[inline(always)]
fn mod_mul64(
    a: u64,
    b: u64,
    modulus: u64,
) -> u64 {
    ((a as u128 * b as u128) % modulus as u128) as u64
}

pub fn mul(a: &mut [u64], b: &[u64], modulo: u64){
    for _ in (0..1000).step_by(4) {
        a[0] = mod_mul64(b[0], a[7], modulo);
        a[1] = mod_mul64(b[1], a[6], modulo);
        a[2] = mod_mul64(b[2], a[5], modulo);
        a[3] = mod_mul64(b[3], a[4], modulo);
        a[4] = mod_mul64(b[4], a[3], modulo);
        a[5] = mod_mul64(b[5], a[2], modulo);
        a[6] = mod_mul64(b[6], a[1], modulo);
        a[7] = mod_mul64(b[7], a[0], modulo);
    }
}

#[allow(unused)]
pub fn main() {
    let a: &mut[u64] = todo!();
    let b: &[u64] = todo!();
    let modulo = todo!();
    mul(a, b, modulo);
    println!("a: {:?}", a);
}

As seen on https://godbolt.org/z/h8zfadz3d, even with optimizations turned on and the target CPU set to native, there are no SIMD instructions in the output (which should start with v, for vector).

I understand that this mod_mul64 implementation may not be SIMD-friendly. What would be an easy way to modify it so that it gets auto-vectorized?

Upvotes: 2

Views: 122

Answers (1)

RedRam

Reputation: 36

Your current code seems wrong and just modifies the first 8 numbers over and over. I'll assume you're trying to vectorize general (a * b) % c math. There are a lot of mod/rem implementations below, but I was only able to get auto-vectorization to work for 32-bit values (with possible overflow) and a constant modulus, or for u64 with a power-of-two modulus.

#![allow(dead_code)]

// `cargo asm --lib --native mul32_const_asm_test 0`
// `cargo asm --lib --native mul32_const_asm_test_aligned 0`
// `cargo asm --lib --native mul64_const_asm_test_aligned 0`
// `cargo asm --lib --native mul_many 0`
// `cargo asm --lib --native mul_many_aligned_u64 0`

// MARK: Const modulus asm tests

#[no_mangle]
pub const fn mul32_const_asm_test(a: &[u32; 8], b: &[u32; 8]) -> [u32; 8] {
    let some_random_number = 2232673653;
    mul_many(a, b, some_random_number)
}

#[no_mangle]
pub const fn mul32_const_asm_test_aligned(
    a: &Aligned64<[u32; 8]>,
    b: &Aligned64<[u32; 8]>,
) -> Aligned64<[u32; 8]> {
    let some_random_number = 2232673653;
    mul_many_aligned(a, b, some_random_number)
}

#[no_mangle]
pub const fn mul64_const_asm_test_aligned(
    a: &Aligned64<[u64; 4]>,
    b: &Aligned64<[u64; 4]>,
) -> Aligned64<[u64; 4]> {
    let some_random_number = 2232673653;
    mul_many_aligned_u64(a, b, some_random_number)
}

// MARK: Non const asm Tests

// NOTE: scalar asm on its own, can be vectorized if `modulo` is a constant.
#[no_mangle]
pub const fn mul_many(a: &[u32; 8], b: &[u32; 8], modulo: u32) -> [u32; 8] {
    // let func = mod_mul32_expanding; // not vectorized
    // let func = mod_mul32_triple_custom; // vectorized, big
    // let func = mod_mul32_simple_custom; // vectorized
    // let func = mod_mul32_triple; // vectorized
    let func = mod_mul32_simple; // vectorized

    let mut out = [0; 8];
    out[0] = func(b[0], a[7], modulo);
    out[1] = func(b[1], a[6], modulo);
    out[2] = func(b[2], a[5], modulo);
    out[3] = func(b[3], a[4], modulo);
    out[4] = func(b[4], a[3], modulo);
    out[5] = func(b[5], a[2], modulo);
    out[6] = func(b[6], a[1], modulo);
    out[7] = func(b[7], a[0], modulo);
    out
}

// NOTE: scalar asm on its own, can be vectorized if `modulo` is a constant.
#[no_mangle]
pub const fn mul_many_aligned(
    a: &Aligned64<[u32; 8]>,
    b: &Aligned64<[u32; 8]>,
    modulo: u32,
) -> Aligned64<[u32; 8]> {
    // let func = mod_mul32_expanding; // not vectorized
    // let func = mod_mul32_triple_custom; // vectorized, big
    // let func = mod_mul32_simple_custom; // vectorized
    // let func = mod_mul32_triple; // vectorized
    let func = mod_mul32_simple; // vectorized

    let mut out = Aligned64([0; 8]);
    out.0[0] = func(b.0[0], a.0[7], modulo);
    out.0[1] = func(b.0[1], a.0[6], modulo);
    out.0[2] = func(b.0[2], a.0[5], modulo);
    out.0[3] = func(b.0[3], a.0[4], modulo);
    out.0[4] = func(b.0[4], a.0[3], modulo);
    out.0[5] = func(b.0[5], a.0[2], modulo);
    out.0[6] = func(b.0[6], a.0[1], modulo);
    out.0[7] = func(b.0[7], a.0[0], modulo);
    out
}

// I couldn't get this vectorized
#[no_mangle]
pub const fn mul_many_aligned_u64(
    a: &Aligned64<[u64; 4]>,
    b: &Aligned64<[u64; 4]>,
    modulo: u64,
) -> Aligned64<[u64; 4]> {
    // let func = mod_mul64_expanding; // not vectorized
    // let func = mod_mul64_simple; // not vectorized
    let func = mod_mul64_simple_custom; // surprisingly not vectorized

    let mut out = Aligned64([0; 4]);
    out.0[0] = func(b.0[0], a.0[3], modulo);
    out.0[1] = func(b.0[1], a.0[2], modulo);
    out.0[2] = func(b.0[2], a.0[1], modulo);
    out.0[3] = func(b.0[3], a.0[0], modulo);
    out
}

// MARK: 32 bit

/// Never overflows
#[inline(always)]
const fn mod_mul32_expanding(a: u32, b: u32, modulus: u32) -> u32 {
    ((a as u64 * b as u64) % modulus as u64) as u32
}

/// Overflows if `modulus`, `a`, and `b` are all huge
#[inline(always)]
const fn mod_mul32_triple(a: u32, b: u32, modulus: u32) -> u32 {
    (a % modulus * b % modulus) % modulus
}

/// Overflows if `a * b` overflows
#[inline(always)]
const fn mod_mul32_simple(a: u32, b: u32, modulus: u32) -> u32 {
    (a * b) % modulus
}

#[inline(always)]
const fn mod_mul32_triple_custom(a: u32, b: u32, modulus: u32) -> u32 {
    rem_u32(rem_u32(a, modulus) * rem_u32(b, modulus), modulus)
}

#[inline(always)]
const fn mod_mul32_simple_custom(a: u32, b: u32, modulus: u32) -> u32 {
    rem_u32(a * b, modulus)
}

// MARK: 64 bit

/// Never overflows
#[inline(always)]
const fn mod_mul64_expanding(a: u64, b: u64, modulus: u64) -> u64 {
    ((a as u128 * b as u128) % modulus as u128) as u64
}

/// Overflows if `a * b` overflows
#[inline(always)]
const fn mod_mul64_simple(a: u64, b: u64, modulus: u64) -> u64 {
    (a * b) % modulus
}

#[inline(always)]
const fn mod_mul64_simple_custom(a: u64, b: u64, modulus: u64) -> u64 {
    rem_u64(a * b, modulus)
}

// MARK: Helpers

/// I don't think it overflows, and I think it gives the exact same results as % for unsigned.
#[inline(always)]
const fn rem_u32(lhs: u32, rhs: u32) -> u32 {
    // TODO: does any of this overflow?
    lhs - (rhs * (lhs / rhs))
}

/// I don't think it overflows, and I think it gives the exact same results as % for unsigned.
#[inline(always)]
const fn rem_u64(lhs: u64, rhs: u64) -> u64 {
    // TODO: does any of this overflow?
    lhs - (rhs * (lhs / rhs))
}

#[repr(align(64))]
pub struct Aligned64<T>(pub T);

impl<T> std::ops::Deref for Aligned64<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T> std::ops::DerefMut for Aligned64<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
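
One case the listing above doesn't show is the u64 power-of-two modulus mentioned at the top. Here's a minimal sketch of what I mean (the function name and modulus are made up for illustration): with a power-of-two modulus, % reduces to a bitwise AND, which needs no widening and should vectorize.

#[no_mangle]
pub const fn mul64_pow2_asm_test(
    a: &Aligned64<[u64; 4]>,
    b: &Aligned64<[u64; 4]>,
) -> Aligned64<[u64; 4]> {
    // Power of two, so `% MODULUS` compiles down to `& (MODULUS - 1)`.
    const MODULUS: u64 = 1 << 32;
    let mut out = Aligned64([0; 4]);
    let mut i = 0;
    while i < 4 {
        // `wrapping_mul` removes the debug overflow panic path; because the
        // modulus divides 2^64, wrapping the product doesn't change the result.
        out.0[i] = b.0[i].wrapping_mul(a.0[3 - i]) % MODULUS;
        i += 1;
    }
    out
}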

You should be able to manually vectorize 64-bit values using the fact that lhs % rhs == lhs - (rhs * (lhs / rhs)) for integers (not sure about overflow, though I don't think it can).
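
On the overflow question: since lhs / rhs rounds down, rhs * (lhs / rhs) is never larger than lhs, so for a nonzero rhs neither the multiplication nor the subtraction can wrap. A quick hypothetical test (assuming it sits next to rem_u64 in the listing above) comparing it against the built-in operator:

#[test]
fn rem_u64_matches_builtin_rem() {
    // A few ordinary and edge-case pairs; `rhs` must be nonzero.
    let samples: [(u64, u64); 5] = [(10, 3), (0, 5), (123_456_789, 97), (u64::MAX, 7), (42, 42)];
    for &(lhs, rhs) in samples.iter() {
        assert_eq!(rem_u64(lhs, rhs), lhs % rhs);
    }
}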

Some tips for auto-vectorization (a small sketch follows the list):

  • RUSTFLAGS="-Ctarget-cpu=native" cargo build --release
  • Align data
  • Use arrays instead of slices
  • Early assert slices length are the same
  • Don't mutate input / don't read input after mutating it
    • Try to keep dependency chains to a minimum: parallel, not serial
  • Remove panic paths
    • div or rem by zero, out of bounds
    • NonZeroU32, get_unchecked()
  • Maybe manual partial loop unrolling
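
Putting a few of those together on your original slice-based signature, a rough sketch (the function name is mine, and whether it actually vectorizes still depends on the target and on the modulus being a known constant):

pub fn mul_mod_const(out: &mut [u32], a: &[u32], b: &[u32]) {
    // Constant, nonzero modulus: the compiler can lower `%` to multiply/shift tricks.
    const MODULUS: u32 = 2_232_673_653;
    // Early asserts: the hot loop below then has no hidden length/bounds panic paths.
    assert_eq!(out.len(), a.len());
    assert_eq!(out.len(), b.len());
    // Zipped iterators: no indexing, no bounds checks, and no serial dependency,
    // because the output is separate from both inputs.
    for ((o, &x), &y) in out.iter_mut().zip(a).zip(b) {
        // Same caveat as `mod_mul32_simple`: wraps if x * y overflows u32.
        *o = x.wrapping_mul(y) % MODULUS;
    }
}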

Upvotes: 1
