ssokolen

How can I write a generic function that takes either an ndarray Array or ArrayView as input?

I am writing a set of mathematical functions using ndarray which I would like to perform on any type of ArrayBase. However, I'm having trouble specifying the traits/types involved.

These basic functions work on OwnedRepr and ViewRepr data, respectively:

use ndarray::{prelude::*, Data}; // 0.13.1

fn sum_owned(x: Array<f64, Ix1>) -> f64 {
    x.sum()
}

fn sum_view(x: ArrayView<f64, Ix1>) -> f64 {
    x.sum()
}

fn main() {
    let a = Array::from_shape_vec((4,), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    println!("{:?}", sum_owned(a.clone()));

    let b = a.slice(s![..]);
    println!("{:?}", sum_view(b));

    // Complains that OwnedRepr is not ViewRepr
    //println!("{:?}", sum_view(a.clone()));
}

I can understand why the commented-out line doesn't compile, but I don't understand generics well enough to write something more... generic.

Here is what I tried:

use ndarray::prelude::*;
use ndarray::Data;

fn sum_general<S>(x: ArrayBase<S, Ix1>) -> f64
where
    S: Data,
{
    x.sum()
}

The compiler error suggests that Data is not specific enough, but I just can't parse it well enough to figure out what the solution should be:

error[E0277]: the trait bound `<S as ndarray::data_traits::RawData>::Elem: std::clone::Clone` is not satisfied
 --> src/lib.rs:8:7
  |
6 |     S: Data,
  |             - help: consider further restricting the associated type: `, <S as ndarray::data_traits::RawData>::Elem: std::clone::Clone`
7 | {
8 |     x.sum()
  |       ^^^ the trait `std::clone::Clone` is not implemented for `<S as ndarray::data_traits::RawData>::Elem`

error[E0277]: the trait bound `<S as ndarray::data_traits::RawData>::Elem: num_traits::identities::Zero` is not satisfied
 --> src/lib.rs:8:7
  |
6 |     S: Data,
  |             - help: consider further restricting the associated type: `, <S as ndarray::data_traits::RawData>::Elem: num_traits::identities::Zero`
7 | {
8 |     x.sum()
  |       ^^^ the trait `num_traits::identities::Zero` is not implemented for `<S as ndarray::data_traits::RawData>::Elem`

error[E0308]: mismatched types
 --> src/lib.rs:8:5
  |
4 | fn sum_general<S>(x: ArrayBase<S, Ix1>) -> f64
  |                                            --- expected `f64` because of return type
...
8 |     x.sum()
  |     ^^^^^^^ expected `f64`, found associated type
  |
  = note:         expected type `f64`
          found associated type `<S as ndarray::data_traits::RawData>::Elem`
  = note: consider constraining the associated type `<S as ndarray::data_traits::RawData>::Elem` to `f64`
  = note: for more information, visit https://doc.rust-lang.org/book/ch19-03-advanced-traits.html

Answers (1)

eggyal

If you look at the definition of the ndarray::ArrayBase::sum method that you're attempting to invoke:

impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    pub fn sum(&self) -> A
    where
        A: Clone + Add<Output = A> + Zero,
    {
        // etc.
    }
}

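In your case A = f64 and D = Ix1. On its own, the bound S: Data only says that S stores elements of some associated type S::Elem. Inside the function, the compiler knows nothing else about that type: not that it is Clone, not that it is Zero, and not that it is f64. That is what your three errors report. Constraining the associated type with S: Data<Elem = f64> fixes all three at once, because f64 satisfies Clone + Add<Output = f64> + Zero. Therefore: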

use ndarray::prelude::*;
use ndarray::Data;

fn sum_general<S>(x: ArrayBase<S, Ix1>) -> f64
where
    S: Data<Elem = f64>,
{
    x.sum()
}

This is exactly what the compiler meant when it suggested:

  = note:         expected type `f64`
          found associated type `<S as ndarray::data_traits::RawData>::Elem`
  = note: consider constraining the associated type `<S as ndarray::data_traits::RawData>::Elem` to `f64`
