hpaantee

Reputation: 147

Function accepting ndarray and float at the same time

I would like to write a function computing the refractive index of some material as a function of wavelength. I use the ndarray crate for arrays, similar to numpy in Python. At the moment the function is implemented the following way:

use ndarray::Array1;
use std::f64::consts;

// `Fiber` is a placeholder for the struct that owns `A`, `B`, `pitch` and `material`.
impl Fiber {
    fn refractive_index(&self, l: &Array1<f64>) -> Array1<f64> {
        let V =
            self.A[0] + self.A[1] / (1. + self.A[2] * (self.A[3] * l / self.pitch).mapv(f64::exp));
        let W =
            self.B[0] + self.B[1] / (1. + self.B[2] * (self.B[3] * l / self.pitch).mapv(f64::exp));
        // `powi` on arrays comes from the `Squared` trait below.
        (self.material.refractive_index(&l).powi(2)
            - 3. * l.powi(2) / 4. / (consts::PI * self.pitch).powi(2) * (V.powi(2) - W.powi(2)))
            .mapv(f64::sqrt)
    }
}
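Written out as a formula, with λ for the wavelength l, Λ for self.pitch and n_m(λ) for self.material.refractive_index(l), the method computes:

$$V(\lambda) = A_0 + \frac{A_1}{1 + A_2 \exp(A_3 \lambda / \Lambda)}, \qquad W(\lambda) = B_0 + \frac{B_1}{1 + B_2 \exp(B_3 \lambda / \Lambda)}$$

$$n(\lambda) = \sqrt{n_m(\lambda)^2 - \frac{3 \lambda^2}{4 (\pi \Lambda)^2} \left( V(\lambda)^2 - W(\lambda)^2 \right)}$$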

I also implemented the following trait to make exponentiation shorter:

trait Squared<T> {
    fn powi(&self, e: i32) -> T;
}

impl Squared<Array1<f64>> for Array1<f64> {
    fn powi(&self, e: i32) -> Array1<f64> {
        self.mapv(|a| a.powi(e))
    }
}

However, I may also want to compute the refractive index at a single wavelength, so I would like the function to accept f64 values as well. What is the best way to do this without writing two separate functions?

Edit: The function uses ndarray's syntax for squaring, l.mapv(|x| x.powi(2)), which unfortunately differs from the syntax for f64.

Edit 2: As requested, I included the function body.

Upvotes: 1

Views: 112

Answers (1)

EvilTak

Reputation: 7579

It is definitely possible, although it may take more work than simply adding another function. You first need to define two helper traits that provide the mapv and powi operations (much like your Squared trait) for both f64 and Array1<f64>:

trait PowI {
    fn powi(&self, e: i32) -> Self;
}

impl PowI for f64 {
    fn powi(&self, e: i32) -> Self {
        f64::powi(*self, e)
    }
}

impl PowI for Array1<f64> {
    fn powi(&self, e: i32) -> Self {
        self.mapv(|a| a.powi(e))
    }
}

trait MapV {
    fn mapv(&self, f: impl FnMut(f64) -> f64) -> Self;
}

impl MapV for f64 {
    fn mapv(&self, mut f: impl FnMut(f64) -> f64) -> Self {
        f(*self)
    }
}

impl MapV for Array1<f64> {
    fn mapv(&self, f: impl FnMut(f64) -> f64) -> Self {
        Array1::mapv(self, f)
    }
}

You can then make your refractive index function generic over a type T that implements both traits (i.e. T: PowI + MapV).

Note that you will also need additional bounds on T specifying that it can be added to, subtracted from, multiplied by and divided by both Ts and f64s (I assume f64 is the type of self.pitch and of the elements of self.A and self.B). You can express this by requiring that T, &'a T, or even f64 implement the appropriate Add, Sub, Mul and Div traits. For example, your implementation may require f64: Mul<&'a T, Output=T> and T: Div<f64, Output=T>, among many other bounds. The compiler errors will help you figure out exactly which ones you need.

Upvotes: 1
