McKayla

Reputation: 6959

Comparing functions for equality in Rust

I have a function which takes a number as an argument, and then returns a function based on the number. Depending on many different things, it might return any of ~50 functions, and the cases for which one it should return get pretty complicated. As such, I want to build some tests to make sure the proper functions are being returned. What I have so far looks roughly like this.

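// Field types are assumed; any struct with fields `a` and `b` works the same way.
struct SomeStruct {
    a: i32,
    b: i32,
}

fn pick_a_function(decider: u32) -> fn(&mut SomeStruct) {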
    match decider {
        1 => add,
        2 => sub,
        _ => zero,
    }
}

fn add(x: &mut SomeStruct) {
    x.a += x.b;
}

fn sub(x: &mut SomeStruct) {
    x.a -= x.b;
}

fn zero(x: &mut SomeStruct) {
    x.a = 0;
}

fn main() {
    let mut x = SomeStruct { a: 2, b: 3 };
    pick_a_function(1)(&mut x);

    println!("2 + 3 = {}", x.a);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn picks_correct_function() {
        // Fails to compile: assert_eq! reports that it can't compare these types.
        assert_eq!(pick_a_function(1), add);
    }
}

The problem is that the functions don't seem to implement the Eq or PartialEq traits, so assert_eq! just says that it can't compare them. What options do I have for comparing the returned function to the correct function?

Upvotes: 3

Views: 1607

Answers (2)

McKayla
McKayla

Reputation: 6959

So it turns out that function pointers in Rust actually do implement PartialEq, as long as there is no lifetime attached to the signature and the function takes 12 or fewer arguments. The restriction exists because each shape of function signature needs its own trait implementation: the compiler treats all of them as completely unrelated types.
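A minimal sketch of the case that does work, assuming a signature with no reference parameters: fn(u32) -> u32 carries no lifetime, so == works on it directly. The names double, triple and pick are hypothetical. A function item such as double has its own unique type, so it is coerced to the pointer type before the comparison. Recent compilers may emit a lint warning about comparing function pointers with ==.

```rust
// Two hypothetical functions whose pointer type, fn(u32) -> u32,
// has no lifetime in it.
fn double(x: u32) -> u32 {
    x * 2
}

fn triple(x: u32) -> u32 {
    x * 3
}

// Returns one of the functions as a plain function pointer.
fn pick(decider: u32) -> fn(u32) -> u32 {
    if decider == 1 { double } else { triple }
}

fn main() {
    // Coerce the function item to the pointer type so both sides of
    // `==` have the same type.
    let expected: fn(u32) -> u32 = double;
    assert!(pick(1) == expected);
    assert!(pick(2) != expected);
    println!("pointer comparison ok");
}
```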

The functions I was returning took a mutable reference to a struct, which implicitly gives the signature a lifetime (the type is really for<'a> fn(&'a mut SomeStruct)), so the return type no longer had a PartialEq implementation. Internally, though, all Rust really does to compare functions is cast both of them to pointers and compare those, so we can do the same thing ourselves.

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn picks_correct_function() {
        assert_eq!(
            pick_a_function(1) as usize,
            add as usize
        );
    }
}
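A self-contained sketch of the same cast-based check, extended to a table of cases since the question mentions ~50 possible functions. The field types of SomeStruct, the Op alias, the same_fn helper and the case table are all assumptions for illustration. One caveat: the standard library documentation for function pointers warns that address comparison is not fully reliable, because the same function can end up at different addresses in different codegen units and identical functions can be merged into one address. Treat this as a practical test heuristic rather than a language guarantee.

```rust
// Field types are assumed; the question does not show the definition.
struct SomeStruct {
    a: i32,
    b: i32,
}

fn add(x: &mut SomeStruct) { x.a += x.b; }
fn sub(x: &mut SomeStruct) { x.a -= x.b; }
fn zero(x: &mut SomeStruct) { x.a = 0; }

// Hypothetical alias for the function-pointer type being returned.
type Op = fn(&mut SomeStruct);

fn pick_a_function(decider: u32) -> Op {
    match decider {
        1 => add,
        2 => sub,
        _ => zero,
    }
}

// Compares two function pointers by address (heuristic, see caveat above).
fn same_fn(f: Op, g: Op) -> bool {
    f as usize == g as usize
}

fn main() {
    // Hypothetical table of (decider, expected function); extend with real cases.
    let cases: [(u32, Op); 4] = [
        (1, add as Op),
        (2, sub as Op),
        (3, zero as Op),
        (99, zero as Op),
    ];
    for &(decider, expected) in cases.iter() {
        assert!(
            same_fn(pick_a_function(decider), expected),
            "wrong function for decider {}",
            decider
        );
    }
    println!("all cases pass");
}
```

A table keeps each new case to one line, and the assertion message points at the decider that failed.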

Upvotes: 3

user11667416


You should compare the result of calling the function instead of the function itself, for example:

#[cfg(test)]
mod tests {
    use super::*;

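    #[test]
    fn picks_correct_function() {
        // Run the picked function and the expected one on identical inputs.
        let mut picked = SomeStruct { a: 1, b: 2 };
        let mut expected = SomeStruct { a: 1, b: 2 };
        pick_a_function(1)(&mut picked);
        add(&mut expected);
        assert_eq!(picked.a, expected.a);
    }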
}
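A runnable sketch of the behavior-based approach over several inputs. It assumes SomeStruct can derive Clone and PartialEq (the question does not show its definition), and the behaves_like helper and sample values are hypothetical. The weakness of this approach is that two different functions that act the same on every sample are indistinguishable: sub and zero both leave a == 0 when a == b, so samples should be chosen to tell the candidates apart.

```rust
// Derives are assumed so whole structs can be copied and compared.
#[derive(Clone, PartialEq)]
struct SomeStruct {
    a: i32,
    b: i32,
}

fn add(x: &mut SomeStruct) { x.a += x.b; }
fn sub(x: &mut SomeStruct) { x.a -= x.b; }
fn zero(x: &mut SomeStruct) { x.a = 0; }

fn pick_a_function(decider: u32) -> fn(&mut SomeStruct) {
    match decider {
        1 => add,
        2 => sub,
        _ => zero,
    }
}

// Returns true if `f` and `g` leave every sample in the same state.
fn behaves_like(f: fn(&mut SomeStruct), g: fn(&mut SomeStruct), samples: &[SomeStruct]) -> bool {
    samples.iter().all(|s| {
        let mut left = s.clone();
        let mut right = s.clone();
        f(&mut left);
        g(&mut right);
        left == right
    })
}

fn main() {
    // Hypothetical samples with a != b, so `sub` and `zero` give different results.
    let samples = [
        SomeStruct { a: 2, b: 3 },
        SomeStruct { a: -5, b: 7 },
    ];
    assert!(behaves_like(pick_a_function(1), add, &samples));
    assert!(!behaves_like(pick_a_function(2), zero, &samples));
    println!("behavior checks pass");
}
```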

Or, in more complex scenarios, you can compare the signatures: if one candidate function takes one parameter and another takes two, try calling the returned function with each argument list and see whether you get a compiler error.
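A minimal sketch of the compile-time variant of this idea, assuming the same SomeStruct fields as above. The helper expects_one_arg and the two-argument function add_two are hypothetical. Passing a function to a parameter of one specific function-pointer type makes the compiler reject any function with a different signature; note this only checks the type, not which function was actually picked.

```rust
// Field types are assumed, as in the question.
struct SomeStruct {
    a: i32,
    b: i32,
}

fn add(x: &mut SomeStruct) {
    x.a += x.b;
}

// Hypothetical function with a different, two-argument signature.
fn add_two(x: &mut SomeStruct, y: i32) {
    x.a += y;
}

// Compiles only for functions that coerce to fn(&mut SomeStruct).
fn expects_one_arg(f: fn(&mut SomeStruct)) -> fn(&mut SomeStruct) {
    f
}

fn main() {
    let mut s = SomeStruct { a: 1, b: 2 };
    expects_one_arg(add)(&mut s);
    // expects_one_arg(add_two); // would not compile: mismatched types
    add_two(&mut s, 10);
    println!("a = {}", s.a);
}
```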

Upvotes: 0
