I am writing a set of mathematical functions using ndarray which I would like to perform on any type of ArrayBase. However, I'm having trouble specifying the traits/types involved.
These basic functions work on OwnedRepr and ViewRepr data, respectively:
use ndarray::{prelude::*, Data}; // 0.13.1

fn sum_owned(x: Array<f64, Ix1>) -> f64 {
    x.sum()
}

fn sum_view(x: ArrayView<f64, Ix1>) -> f64 {
    x.sum()
}

fn main() {
    let a = Array::from_shape_vec((4,), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    println!("{:?}", sum_owned(a.clone()));

    let b = a.slice(s![..]);
    println!("{:?}", sum_view(b));

    // Complains that OwnedRepr is not ViewRepr
    //println!("{:?}", sum_view(a.clone()));
}
I can understand why the commented-out line doesn't compile, but I don't understand generics well enough to write something more... generic.
Here is what I tried:
use ndarray::prelude::*;
use ndarray::Data;

fn sum_general<S>(x: ArrayBase<S, Ix1>) -> f64
where
    S: Data,
{
    x.sum()
}
The compiler error suggests that Data is not specific enough, but I can't parse it well enough to figure out what the solution should be:
error[E0277]: the trait bound `<S as ndarray::data_traits::RawData>::Elem: std::clone::Clone` is not satisfied
--> src/lib.rs:8:7
|
6 | S: Data,
| - help: consider further restricting the associated type: `, <S as ndarray::data_traits::RawData>::Elem: std::clone::Clone`
7 | {
8 | x.sum()
| ^^^ the trait `std::clone::Clone` is not implemented for `<S as ndarray::data_traits::RawData>::Elem`
error[E0277]: the trait bound `<S as ndarray::data_traits::RawData>::Elem: num_traits::identities::Zero` is not satisfied
--> src/lib.rs:8:7
|
6 | S: Data,
| - help: consider further restricting the associated type: `, <S as ndarray::data_traits::RawData>::Elem: num_traits::identities::Zero`
7 | {
8 | x.sum()
| ^^^ the trait `num_traits::identities::Zero` is not implemented for `<S as ndarray::data_traits::RawData>::Elem`
error[E0308]: mismatched types
--> src/lib.rs:8:5
|
4 | fn sum_general<S>(x: ArrayBase<S, Ix1>) -> f64
| --- expected `f64` because of return type
...
8 | x.sum()
| ^^^^^^^ expected `f64`, found associated type
|
= note: expected type `f64`
found associated type `<S as ndarray::data_traits::RawData>::Elem`
= note: consider constraining the associated type `<S as ndarray::data_traits::RawData>::Elem` to `f64`
= note: for more information, visit https://doc.rust-lang.org/book/ch19-03-advanced-traits.html
If you look at the definition of the ndarray::ArrayBase::sum function that you're attempting to invoke:
impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    pub fn sum(&self) -> A
    where
        A: Clone + Add<Output = A> + Zero
    {
        // etc.
    }
}
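In your case A = f64 and D = Ix1, but you still need to tell the compiler that the element type is f64 by specifying the constraint S: Data<Elem = f64>. Since f64 already satisfies Clone + Add<Output = f64> + Zero, that one bound resolves all three errors. Therefore: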
use ndarray::prelude::*;
use ndarray::Data;

fn sum_general<S>(x: ArrayBase<S, Ix1>) -> f64
where
    S: Data<Elem = f64>,
{
    x.sum()
}
This is exactly what the compiler meant when it suggested:
= note: expected type `f64`
found associated type `<S as ndarray::data_traits::RawData>::Elem`
= note: consider constraining the associated type `<S as ndarray::data_traits::RawData>::Elem` to `f64`