ryn1x

Reputation: 1132

How can I improve the performance of element-wise multiplication in Rust?

I will be doing element-wise multiplication on multiple vectors with 10^6+ elements. This is being flagged in profiling as one of the slowest parts of my code, so how can I improve it?

/// element-wise multiplication for vecs
pub fn vec_mul<T>(v1: &Vec<T>, v2: &Vec<T>) -> Vec<T>
where
    T: std::ops::Mul<Output = T> + Copy,
{
    if v1.len() != v2.len() {
        panic!("Cannot multiply vectors of different lengths!")
    }
    let mut out: Vec<T> = Vec::with_capacity(v1.len());
    for i in 0..(v1.len()) {
        out.push(v1[i] * v2[i]);
    }
    out
}

Upvotes: 7

Views: 2407

Answers (1)

Francis Gagné

Reputation: 65682

When you index into a Vec or a slice with [], the compiled code has to check at run time that the index is in bounds (unless the optimizer can prove that the check is redundant).

However, when you use iterators, these bounds checks are omitted, because the iterators have been carefully written to ensure that they never read out of bounds. Furthermore, due to how borrowing works in Rust, a data structure cannot be mutated while an iterator exists over that data structure (except via that iterator itself), so it's impossible for the valid bounds to change during iteration.

Since you are iterating over two different data structures in lockstep, you'll want to use the zip iterator adapter. zip stops as soon as either iterator is exhausted, so it's still worth validating up front that both vectors have the same length. zip produces an iterator of tuples, where each tuple contains the items at the same position in the two original iterators. Then you can use map to transform each tuple into the product of the two values. Finally, you'll want to collect the new iterator produced by map into a Vec, which you can then return from your function. collect uses size_hint to preallocate memory for the vector using Vec::with_capacity.

/// element-wise multiplication for vecs
pub fn vec_mul<T>(v1: &[T], v2: &[T]) -> Vec<T>
where
    T: std::ops::Mul<Output = T> + Copy,
{
    if v1.len() != v2.len() {
        panic!("Cannot multiply vectors of different lengths!")
    }

    v1.iter().zip(v2).map(|(&i1, &i2)| i1 * i2).collect()
}
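
To illustrate why the length check is still worth keeping, here is a minimal sketch of how zip treats inputs of unequal length and of the size hint that collect gets to work with. The helper names zip_products and zip_size_hint are made up for this example and the values are arbitrary.

```rust
/// Hypothetical helper: multiplies pairwise WITHOUT checking lengths first.
fn zip_products(a: &[i32], b: &[i32]) -> Vec<i32> {
    // zip stops at the shorter input, so any extra elements are silently dropped.
    a.iter().zip(b).map(|(&x, &y)| x * y).collect()
}

/// Hypothetical helper: returns the size hint of the zipped and mapped iterator.
fn zip_size_hint(a: &[i32], b: &[i32]) -> (usize, Option<usize>) {
    // Slice iterators know their exact length; zip reports the smaller of the two.
    a.iter().zip(b).map(|(&x, &y)| x * y).size_hint()
}
```

With a = [1, 2, 3, 4] and b = [10, 20, 30], zip_products returns only three products and the trailing 4 is ignored without any error, which is exactly the situation the explicit length check guards against.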

Note: I've changed the signature to take slices instead of references to vectors. See "Why is it discouraged to accept a reference to a String (&String), Vec (&Vec), or Box (&Box) as a function argument?" for more information.
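
As a rough sketch of what the slice signature buys you, the same function can be called with vectors, arrays, and sub-slices. The demo function and its values are hypothetical; vec_mul is repeated only so that the block compiles on its own.

```rust
// Repeated from the answer above so that this block is self-contained.
pub fn vec_mul<T>(v1: &[T], v2: &[T]) -> Vec<T>
where
    T: std::ops::Mul<Output = T> + Copy,
{
    if v1.len() != v2.len() {
        panic!("Cannot multiply vectors of different lengths!")
    }

    v1.iter().zip(v2).map(|(&i1, &i2)| i1 * i2).collect()
}

/// Hypothetical callers showing what a `&[T]` parameter accepts.
pub fn demo() -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let v = vec![1.0, 2.0, 3.0];
    let arr = [4.0, 5.0, 6.0];

    // &Vec<f64> coerces to &[f64], so existing callers keep working.
    let from_vecs = vec_mul(&v, &v);
    // A reference to a fixed-size array coerces to a slice as well.
    let from_mixed = vec_mul(&v, &arr);
    // Sub-slices let you multiply just part of each buffer.
    let from_parts = vec_mul(&v[..2], &arr[1..]);

    (from_vecs, from_mixed, from_parts)
}
```

None of these calls would type-check against a &Vec<T> parameter except the first, which is the practical argument for accepting slices.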

Upvotes: 11
