user9604623

Reputation:

How to simplify mathematical formulas with rust macros?

I must admit I'm a bit lost with macros. I want to build a macro that performs the following task, but I'm not sure how to do it. I want to perform a scalar product of two arrays, say x and y, which have the same length N. The result I want to compute is of the form:

z = sum_{i=0}^{N-1} x[i] * y[i].

x is a const array whose elements are 0, 1, or -1, known at compile time, while y's elements are determined at runtime. Because of the structure of x, many computations are useless (terms multiplied by 0 can be removed from the sum, and multiplications of the form 1 * y[i] and -1 * y[i] can be transformed into y[i] and -y[i] respectively).

As an example, if x = [-1, 1, 0], the scalar product above would be

z = -1 * y[0] + 1 * y[1] + 0 * y[2]

To speed up my computation I can unroll the loop by hand and rewrite the whole thing without x[i], and I could hard-code the above formula as

z = -y[0] + y[1]

But this procedure is not elegant, error-prone, and very tedious when N becomes large.

I'm pretty sure I can do that with a macro, but I don't know where to start (the different books I read are not going too deep into macros and I'm stuck)...

Would any of you have an idea how to solve this problem (if it is possible) using macros?

Thank you in advance for your help!

Edit: As pointed out in many of the answers, the compiler is smart enough to optimize the loop in the case of integers. I am not only using integers but also floats (the x array is i32s, but in general y is f64s), so the compiler is not smart enough (and rightfully so) to optimize the loop. The following piece of code produces the asm below.

const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

pub fn dot_x(y: [f64; 8]) -> f64 {
    X.iter().zip(y.iter()).map(|(i, j)| (*i as f64) * j).sum()
}
playground::dot_x:
    xorpd   %xmm0, %xmm0
    movsd   (%rdi), %xmm1
    mulsd   %xmm0, %xmm1
    addsd   %xmm0, %xmm1
    addsd   8(%rdi), %xmm1
    subsd   16(%rdi), %xmm1
    movupd  24(%rdi), %xmm2
    xorpd   %xmm3, %xmm3
    mulpd   %xmm2, %xmm3
    addsd   %xmm3, %xmm1
    unpckhpd    %xmm3, %xmm3
    addsd   %xmm1, %xmm3
    addsd   40(%rdi), %xmm3
    mulsd   48(%rdi), %xmm0
    addsd   %xmm3, %xmm0
    subsd   56(%rdi), %xmm0
    retq

Upvotes: 7

Views: 1630

Answers (4)

lu_zero

Reputation: 998

If you can spare an #[inline(always)], using an explicit filter_map() should probably be enough to have the compiler do what you want.
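A sketch of that suggestion, using the X array and signature from the question: filter_map drops the zero coefficients before any multiplication happens, so the optimizer never has to reason about a 0.0 * y[i] term.

```rust
const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

// Skip the zero coefficients with filter_map so the float multiplications
// the compiler cannot legally remove are never emitted in the first place.
#[inline(always)]
pub fn dot_x(y: [f64; 8]) -> f64 {
    X.iter()
        .zip(y.iter())
        .filter_map(|(&x, &y)| if x != 0 { Some(x as f64 * y) } else { None })
        .sum()
}
```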

Upvotes: 2

Sven Marnach

Reputation: 602365

In many cases, the optimisation stage of the compiler will take care of this for you. To give an example, this function definition

const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

pub fn dot_x(y: [i32; 8]) -> i32 {
    X.iter().zip(y.iter()).map(|(i, j)| i * j).sum()
}

results in this assembly output on x86_64:

playground::dot_x:
    mov eax, dword ptr [rdi + 4]
    sub eax, dword ptr [rdi + 8]
    add eax, dword ptr [rdi + 20]
    sub eax, dword ptr [rdi + 28]
    ret

You won't be able to get any more optimised version than this, so simply writing the code in a naïve way is the best solution. Whether the compiler will unroll the loop for longer vectors is unclear, and it may change with compiler versions.

For floating-point numbers, the compiler is not normally able to perform all the optimisations above, since the numbers in y are not guaranteed to be finite – they could also be NaN, inf or -inf. For this reason, multiplying with 0.0 is not guaranteed to result in 0.0 again, so the compiler needs to keep the multiplication instructions in the code. You can explicitly allow it to assume all numbers are finite, though, by using the fmul_fast() intrinsic function:

#![feature(core_intrinsics)]
use std::intrinsics::fmul_fast;

const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

pub fn dot_x(y: [f64; 8]) -> f64 {
    X.iter().zip(y.iter()).map(|(i, j)| unsafe { fmul_fast(*i as f64, *j) }).sum()
}

This results in the following assembly code:

playground::dot_x: # @playground::dot_x
# %bb.0:
    xorpd   xmm1, xmm1
    movsd   xmm0, qword ptr [rdi + 8] # xmm0 = mem[0],zero
    addsd   xmm0, xmm1
    subsd   xmm0, qword ptr [rdi + 16]
    addsd   xmm0, xmm1
    addsd   xmm0, qword ptr [rdi + 40]
    addsd   xmm0, xmm1
    subsd   xmm0, qword ptr [rdi + 56]
    ret

This still redundantly adds zeros between the steps, but I would not expect this to result in any measurable overhead for realistic CFD simulations, since such simulations tend to be limited by memory bandwidth rather than CPU. If you want to avoid these additions as well, you need to use fadd_fast() for the additions to allow the compiler to optimise further:

#![feature(core_intrinsics)]
use std::intrinsics::{fadd_fast, fmul_fast};

const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

pub fn dot_x(y: [f64; 8]) -> f64 {
    let mut result = 0.0;
    for (&i, &j) in X.iter().zip(y.iter()) {
        unsafe { result = fadd_fast(result, fmul_fast(i as f64, j)); }
    }
    result
}

This results in the following assembly code:

playground::dot_x: # @playground::dot_x
# %bb.0:
    movsd   xmm0, qword ptr [rdi + 8] # xmm0 = mem[0],zero
    subsd   xmm0, qword ptr [rdi + 16]
    addsd   xmm0, qword ptr [rdi + 40]
    subsd   xmm0, qword ptr [rdi + 56]
    ret

As with all optimisations, you should start with the most readable and maintainable version of the code. If performance becomes an issue, you should profile your code and find the bottlenecks. As the next step, try to improve the fundamental approach, e.g. by using an algorithm with a better asymptotic complexity. Only then should you turn to micro-optimisations like the one you suggested in the question.

Upvotes: 2

Mara

Reputation: 1015

First of all, a (proc) macro can simply not look inside your array x. All it gets are the tokens you pass it, without any context. If you want it to know about the values (0, 1, -1), you need to pass those directly to your macro:

let result = your_macro!(y, -1, 0, 1, -1);
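For completeness, a declarative macro along those lines could do the simplification at expansion time. A minimal sketch, under the assumption that the coefficients are passed as space-separated literal tokens and using a hypothetical dot_x! name; each arm branches on the next coefficient token and drops the zero terms entirely:

```rust
// Hypothetical sketch: coefficients are literal tokens (0, 1, or -1),
// matched one at a time; zero terms never appear in the expansion.
macro_rules! dot_x {
    // no coefficients left: contribute nothing
    (@acc $y:ident, $i:expr, ) => { 0.0 };
    // 0: skip this term entirely
    (@acc $y:ident, $i:expr, 0 $($rest:tt)*) => {
        dot_x!(@acc $y, $i + 1, $($rest)*)
    };
    // 1: the term is just y[i]
    (@acc $y:ident, $i:expr, 1 $($rest:tt)*) => {
        $y[$i] + dot_x!(@acc $y, $i + 1, $($rest)*)
    };
    // -1: the term is -y[i]
    (@acc $y:ident, $i:expr, -1 $($rest:tt)*) => {
        -$y[$i] + dot_x!(@acc $y, $i + 1, $($rest)*)
    };
    // entry point: dot_x!(y, -1 1 0)
    ($y:expr, $($xs:tt)*) => {{
        let y = $y;
        dot_x!(@acc y, 0, $($xs)*)
    }};
}
```

With this, dot_x!(y, -1 1 0) expands to roughly -y[0] + (y[1] + 0.0); a trailing + 0.0 can remain when the list ends in skipped terms, which the None-seeding trick below would also remove.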

But you don't really need a macro for this. The compiler optimizes a lot, as also shown in the other answers. However, as you already mention in your edit, it will not optimize away 0.0 * y[i], as the result of that is not always 0.0. (It could be -0.0 or NaN, for example.) What we can do here is simply help the optimizer a bit by using a match or if, to make sure it does nothing in the 0.0 * y case:

const X: [i32; 8] = [0, -1, 0, 0, 0, 0, 1, 0];

fn foobar(y: [f64; 8]) -> f64 {
    let mut sum = 0.0;
    for (&x, &y) in X.iter().zip(&y) {
        if x != 0 {
            sum += x as f64 * y;
        }
    }
    sum
}

In release mode, the loop is unrolled and the values of X inlined, resulting in most iterations being thrown away as they don't do anything. The only thing left in the resulting binary (on x86_64), is:

foobar:
 xorpd   xmm0, xmm0
 subsd   xmm0, qword ptr [rdi + 8]
 addsd   xmm0, qword ptr [rdi + 48]
 ret

(As suggested by @lu-zero, this can also be done using filter_map. That will look like this: X.iter().zip(&y).filter_map(|(&x, &y)| match x { 0 => None, _ => Some(x as f64 * y) }).sum(), and gives the exact same generated assembly. Or even without a match, by using filter and map separately: .filter(|(&x, _)| x != 0).map(|(&x, &y)| x as f64 * y).sum().)

Pretty good! However, this function calculates 0.0 - y[1] + y[6], since sum started at 0.0 and we only subtract and add things to it. The optimizer is again not willing to optimize away a 0.0. We can help it a bit more by not starting at 0.0, but starting with None:

fn foobar(y: [f64; 8]) -> f64 {
    let mut sum = None;
    for (&x, &y) in X.iter().zip(&y) {
        if x != 0 {
            let p = x as f64 * y;
            sum = Some(sum.map_or(p, |s| s + p));
        }
    }
    sum.unwrap_or(0.0)
}

This results in:

foobar:
 movsd   xmm0, qword ptr [rdi + 48]
 subsd   xmm0, qword ptr [rdi + 8]
 ret

Which simply does y[6] - y[1]. Bingo!

Upvotes: 7

asky

Reputation: 1742

You may be able to achieve your goal with a macro that returns a function.

First, write this function without a macro. This one takes a fixed number of parameters.

fn main() {
    println!("Hello, world!");
    let func = gen_sum([1,2,3]);
    println!("{}", func([4,5,6])) // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
}

fn gen_sum(xs: [i32; 3]) -> impl Fn([i32;3]) -> i32 {
    move |ys| ys[0]*xs[0] + ys[1]*xs[1] + ys[2]*xs[2]
}

Now, completely rewrite it, because the prior design doesn't work well as a macro. We had to give up on fixed-size arrays, as macros appear unable to allocate fixed-size arrays.

Rust Playground

fn main() {
    let func = gen_sum!(1,2,3);
    println!("{}", func(vec![4,5,6])) // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
}

#[macro_export]
macro_rules! gen_sum {
    ( $( $x:expr ),* ) => {
        {
            let mut xs = Vec::new();
            $(
                xs.push($x);
            )*
            move |ys:Vec<i32>| {
                if xs.len() != ys.len() {
                    panic!("lengths don't match")
                }
                let mut total = 0;
                for i in 0..xs.len() {
                    total += xs[i] * ys[i];
                }
                total
            } 
        }
    };
}

What does this do/What should it do

At compile time, it generates a lambda. This lambda accepts a list of numbers and multiplies it by a vec that was generated at compile time. I don't think this was exactly what you were after, as it does not optimize away zeroes at compile time.

You could optimize away zeroes at compile time, but you would necessarily incur some cost at run-time by having to check where the zeroes were in x to determine which elements of y to multiply by. You could even make this lookup constant-time using a hash set. It's still probably not worth it in general (where I presume 0 is not all that common).

Computers are better at doing one thing that's "inefficient" than they are at detecting that the thing they're about to do is "inefficient" and then skipping it. This abstraction breaks down when a significant portion of the operations they do are "inefficient".
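If you do want to pay that lookup cost only once, one option is to precompute the non-zero (index, coefficient) pairs when the closure is built, so each later call only touches the useful terms. A sketch along those lines, reusing the gen_sum name from above:

```rust
// Sketch: scan xs once at closure-creation time, keeping only the
// non-zero coefficients together with their positions; each call of
// the returned closure then skips the zero terms for free.
fn gen_sum(xs: Vec<i32>) -> impl Fn(Vec<i32>) -> i32 {
    let terms: Vec<(usize, i32)> = xs
        .into_iter()
        .enumerate()
        .filter(|&(_, x)| x != 0)
        .collect();
    move |ys: Vec<i32>| terms.iter().map(|&(i, x)| x * ys[i]).sum()
}
```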

Follow-up

Was that worth it? Does it improve run times? I didn't measure, but it seems like understanding and maintaining the macro I wrote isn't worth it compared to just using a function. Writing a macro that does the zero optimization you talked about would probably be even less pleasant.

Upvotes: 3
