exocortex

Reputation: 503

Require scalar multiplication of a generic type with f64 in Rust

I feel that I have two halves of a problem and don't know how to fit them together. On the one hand, there is the "derive_more" crate, which allows you to automatically derive "Add" and "Mul" (and more) traits for specific types. I have used it successfully to implement the function below for a specific type for which I derived these traits.

On the other hand, I want to write a generic function that will later use these "Add" and "Mul" (and more) traits. However, it does not work. (For anyone interested: I want to implement a generic Runge-Kutta function for solving differential equations.)

pub fn update_rk4<T>(f: fn(&T) -> T, state: &mut T, dt: f64)
where
    T: Sized 
    + std::ops::Mul<Output = T> 
    + std::ops::Add<Output = T> 
    + std::ops::AddAssign,
{
    // runge kutta 4 method creates 4 "helper steps"
    let k1 = f(state);
    let k2 = f(&(*state + k1 * 0.5 * dt));
    let k3 = f(&(*state + k2 * 0.5 * dt));
    let k4 = f(&(*state + k3 * dt));

    *state += (k1 + k2 * 2.0 + k3 * 2.0 + k4) * (1.0 / 6.0 * dt);
}

I am stuck here. I get lots of error messages like this:


error[E0308]: mismatched types
 --> src/integrator.rs:7:32
  |
1 | pub fn update_rk4_with_f<T>(f: fn(&T) -> T, state: &mut T, dt: f64)
  |                          - this type parameter
...
7 |     let k2 = f(&(*state + k1 * 0.5 * dt));
  |                                ^^^ expected type parameter `T`, found floating-point number
  |
  = note: expected type parameter `T`
                       found type `{float}`

What am I doing wrong here? Also: what is the name of the problem that I am trying to solve? I think I am also missing a good search term to put into Google.

Edit:

With the help of the nice answer from Matthieu M. below, I fixed my problem. I tried replacing "Mul<Output = T>" with "Mul<Rhs = f64, Output = T>", but still got errors. The idea was the right one: the right-hand type "Rhs" defaults to "Self" (here "T"), so it has to be set explicitly. However, "Rhs" is a generic type parameter rather than an associated type, so it is given by position, not by name. The solution is simply "Mul<f64, Output = T>". Adding the "Copy" bound (needed because "*state" and the "k" values are used by value more than once) made it all work. The final working code looks like this:

pub fn update_rk4<T>(f: fn(&T) -> T, state: &mut T, dt: f64)
where
    T: Sized + Copy
    + std::ops::Mul<f64, Output = T> 
    + std::ops::Add<T, Output = T> 
    + std::ops::AddAssign,
{
    // runge kutta 4 method creates 4 "helper steps"
    let k1 = f(state);
    let k2 = f(&(*state + k1 * 0.5 * dt));
    let k3 = f(&(*state + k2 * 0.5 * dt));
    let k4 = f(&(*state + k3 * dt));

    *state += (k1 + k2 * 2.0 + k3 * 2.0 + k4) * (1.0 / 6.0 * dt);
}
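For illustration, here is a self-contained sketch that exercises the final update_rk4 with a user-defined type. It assumes a hypothetical State struct (position and velocity of the oscillator x'' = -x) and writes the Add, AddAssign and Mul<f64> impls by hand as a stand-in for what derive_more would generate; the Mul<f64> impl is what satisfies the Mul<f64, Output = T> bound. The Sized bound is dropped because generic type parameters are Sized by default.

```rust
use std::ops::{Add, AddAssign, Mul};

// Hypothetical state type: position x and velocity v.
#[derive(Clone, Copy, Debug, PartialEq)]
struct State {
    x: f64,
    v: f64,
}

// Hand-written stand-ins for what derive_more would generate.
impl Add for State {
    type Output = State;
    fn add(self, rhs: State) -> State {
        State { x: self.x + rhs.x, v: self.v + rhs.v }
    }
}

impl AddAssign for State {
    fn add_assign(&mut self, rhs: State) {
        self.x += rhs.x;
        self.v += rhs.v;
    }
}

// Mixed arithmetic: State * f64 -> State (Rhs = f64 instead of the default Self).
impl Mul<f64> for State {
    type Output = State;
    fn mul(self, rhs: f64) -> State {
        State { x: self.x * rhs, v: self.v * rhs }
    }
}

pub fn update_rk4<T>(f: fn(&T) -> T, state: &mut T, dt: f64)
where
    T: Copy + Mul<f64, Output = T> + Add<T, Output = T> + AddAssign,
{
    // Runge-Kutta 4 builds four "helper steps".
    let k1 = f(state);
    let k2 = f(&(*state + k1 * 0.5 * dt));
    let k3 = f(&(*state + k2 * 0.5 * dt));
    let k4 = f(&(*state + k3 * dt));
    *state += (k1 + k2 * 2.0 + k3 * 2.0 + k4) * (1.0 / 6.0 * dt);
}

// x'' = -x written as the first-order system x' = v, v' = -x.
fn oscillator(s: &State) -> State {
    State { x: s.v, v: -s.x }
}

fn main() {
    let mut s = State { x: 1.0, v: 0.0 };
    for _ in 0..100 {
        update_rk4(oscillator, &mut s, 0.01);
    }
    // The exact solution at t = 1 is x = cos(1), v = -sin(1).
    println!("{:?} vs exact x = {}, v = {}", s, 1.0_f64.cos(), -1.0_f64.sin());
}
```

Because f64 itself implements Mul<f64>, Add and AddAssign, the same function also accepts a plain f64 state.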

Upvotes: 1

Views: 565

Answers (1)

Matthieu M.

Reputation: 299810

Also: what is the name of the problem that I am trying to solve?

I am not aware of any specific name.

What am I doing wrong here?

Look more carefully at the Add trait definition:

pub trait Add<Rhs = Self> {
    type Output;

    fn add(self, rhs: Rhs) -> Self::Output;
}

There is a defaulted type parameter, Rhs, denoting the type of the right-hand operand: unless you specify it when writing Add<...>, it is Self.

This is exactly what the error message is telling you:

^^^ expected type parameter `T`, found floating-point number

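You specified T: Mul<Output = T>, and thus the right-hand operand of * is supposed to be a T, and 0.5 is a floating-point literal, not a T. (Mul, the operation failing here, is declared just like Add, with the same Rhs = Self default.)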
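A small sketch of what the defaulted parameter means in practice, using a hypothetical Meters newtype: impl Add for Meters is shorthand for impl Add<Meters> for Meters, a second impl may use a different Rhs, and a generic bound chooses Rhs by position.

```rust
use std::ops::Add;

// Hypothetical newtype, for illustration only.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Meters(f64);

// Rhs is omitted, so it defaults to Self: this is `impl Add<Meters> for Meters`.
impl Add for Meters {
    type Output = Meters;
    fn add(self, rhs: Meters) -> Meters {
        Meters(self.0 + rhs.0)
    }
}

// A second impl with an explicit, different Rhs.
impl Add<f64> for Meters {
    type Output = Meters;
    fn add(self, rhs: f64) -> Meters {
        Meters(self.0 + rhs)
    }
}

// `Add<Output = T>` means `Add<T, Output = T>`: both operands are T.
fn add_same<T: Add<Output = T>>(a: T, b: T) -> T {
    a + b
}

// Rhs is a generic parameter, so it is written positionally. `Add<Rhs = f64, ...>`
// is rejected, because only associated types like Output are named with `=`.
fn add_scalar<T: Add<f64, Output = T>>(a: T, b: f64) -> T {
    a + b
}

fn main() {
    println!("{:?}", add_same(Meters(1.0), Meters(2.0))); // prints Meters(3.0)
    println!("{:?}", add_scalar(Meters(1.0), 0.5)); // prints Meters(1.5)
}
```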


You can solve this problem in one of several ways:

  1. Mixed Arithmetic: Implement Mul<f32> or Mul<f64> (and likewise Add, if needed) for T, and specify T: Mul<fXX, Output = T> as a bound.
  2. Infallible Conversion: Add a supplementary bound T: From<f32> (or T: From<f64>) and convert the literals to T prior to attempting the operation.
  3. Custom Trait: Create a new trait with the constants you need, and implement the trait for all the T you'd like.

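While appealing, mixed arithmetic is probably not the solution you want if T is intended to be a primitive, as the orphan rule prevents you from adding the implementations yourself (for example, f32 does not implement Mul<f64>, and you cannot add that impl).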

Infallible conversion works well if any f32 (or f64) can be converted to a T. Using TryFrom would allow fallible conversion, but would also require returning the error, which would worsen ergonomics.
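A minimal sketch of the Infallible Conversion approach, assuming T: From<f32> plus the T-by-T operators. The name update_rk4_from is hypothetical, and taking dt as a T (rather than an f64) is an assumption made so that no f64-to-T conversion is needed. The final division by 6 uses an extra Div bound so that every converted constant (0.5, 2.0, 6.0) is exactly representable in f32.

```rust
use std::ops::{Add, AddAssign, Div, Mul};

// Hypothetical variant of update_rk4: constants are converted into T via From<f32>.
fn update_rk4_from<T>(f: fn(&T) -> T, state: &mut T, dt: T)
where
    T: Copy + From<f32> + Add<Output = T> + Mul<Output = T> + Div<Output = T> + AddAssign,
{
    // Convert the literals once; all three are exact in f32.
    let half = T::from(0.5_f32);
    let two = T::from(2.0_f32);
    let six = T::from(6.0_f32);

    let k1 = f(state);
    let k2 = f(&(*state + k1 * half * dt));
    let k3 = f(&(*state + k2 * half * dt));
    let k4 = f(&(*state + k3 * dt));
    *state += (k1 + k2 * two + k3 * two + k4) * (dt / six);
}

// y' = -y, whose exact solution from y(0) = 1 is y(t) = exp(-t).
fn decay64(y: &f64) -> f64 {
    -*y
}

fn decay32(y: &f32) -> f32 {
    -*y
}

fn main() {
    let mut y = 1.0_f64;
    for _ in 0..100 {
        update_rk4_from(decay64, &mut y, 0.01);
    }
    println!("f64: y(1) ~ {} (exact {})", y, (-1.0_f64).exp());

    let mut z = 1.0_f32;
    for _ in 0..100 {
        update_rk4_from(decay32, &mut z, 0.01);
    }
    println!("f32: y(1) ~ {} (exact {})", z, (-1.0_f32).exp());
}
```

Both f32 and f64 satisfy T: From<f32> (f32 through the reflexive From<T> for T impl), whereas f32 does not implement From<f64>, which is one reason to prefer the f32 bound.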

The last approach, a Custom Trait, is more heavy-handed, but circumvents the limitations of the previous two.
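A minimal sketch of the Custom Trait approach; the trait name Rk4Scalar, its constants, and update_rk4_custom are hypothetical. The trait bundles the arithmetic bounds as supertraits and carries the constants as associated consts, so any type, including one that cannot be built from an f32, can opt in with a single impl.

```rust
use std::ops::{Add, AddAssign, Mul};

// Hypothetical trait: arithmetic bounds plus the constants RK4 needs.
trait Rk4Scalar: Copy + Add<Output = Self> + Mul<Output = Self> + AddAssign {
    const HALF: Self;
    const TWO: Self;
    const SIXTH: Self;
}

impl Rk4Scalar for f32 {
    const HALF: f32 = 0.5;
    const TWO: f32 = 2.0;
    const SIXTH: f32 = 1.0 / 6.0;
}

impl Rk4Scalar for f64 {
    const HALF: f64 = 0.5;
    const TWO: f64 = 2.0;
    const SIXTH: f64 = 1.0 / 6.0;
}

// A single bound replaces the long where clause.
fn update_rk4_custom<T: Rk4Scalar>(f: fn(&T) -> T, state: &mut T, dt: T) {
    let k1 = f(state);
    let k2 = f(&(*state + k1 * T::HALF * dt));
    let k3 = f(&(*state + k2 * T::HALF * dt));
    let k4 = f(&(*state + k3 * dt));
    *state += (k1 + k2 * T::TWO + k3 * T::TWO + k4) * (T::SIXTH * dt);
}

// y' = -y, whose exact solution from y(0) = 1 is y(t) = exp(-t).
fn decay64(y: &f64) -> f64 {
    -*y
}

fn decay32(y: &f32) -> f32 {
    -*y
}

fn main() {
    let mut y = 1.0_f64;
    for _ in 0..100 {
        update_rk4_custom(decay64, &mut y, 0.01);
    }
    println!("f64: y(1) ~ {} (exact {})", y, (-1.0_f64).exp());

    let mut z = 1.0_f32;
    for _ in 0..100 {
        update_rk4_custom(decay32, &mut z, 0.01);
    }
    println!("f32: y(1) ~ {} (exact {})", z, (-1.0_f32).exp());
}
```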

Which approach to use depends on what T is supposed to be. I personally favor Infallible Conversion when possible, as it is a lightweight requirement (it doesn't require implementing 5 or 10 traits), or a Custom Trait when conversion cannot be infallible.

Upvotes: 1
