Theodor Johnson

Reputation: 257

Can I have optional trait bounds?

Given a trait that models a conditional probability distribution:

trait Distribution {
    type T;
    fn sample<U>(&self, x: U) -> Self::T;
}

I want to implement the trait for two structs, ConditionalNormal and MultivariateConditionalNormal, which model a scalar-valued and a vector-valued distribution respectively.

Such implementations look like this:

struct ConditionalNormal;

impl Distribution for ConditionalNormal {
    type T = f64;

    fn sample<U>(&self, x: U) -> Self::T {
        0.0
    }
}

struct MultivariateConditionalNormal;

impl Distribution for MultivariateConditionalNormal {
    type T = f64;

    fn sample<U>(&self, x: U) -> Self::T {
        0.0 + x[0]
    }
}

However, the implementation for MultivariateConditionalNormal does not compile: nothing is known about the generic type U, so x[0] is not allowed. If I add the bound std::ops::Index<usize> to U in the trait definition, ConditionalNormal can no longer be sampled with a scalar f64, because f64 is not indexable.

I have heard that the Sized bound can be made optional with ?Sized. Can I do something similar for other traits? Is there any way to resolve this problem?
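For context, here is a minimal sketch of what `?Sized` actually does: it relaxes only the implicit `Sized` bound that every generic parameter carries by default, which lets a function accept unsized types such as `str` or `[u8]` behind a reference. The function name `byte_len` is hypothetical, chosen for illustration. Rust offers no analogous relaxed bound (such as a `?Index`) for other traits, so this mechanism by itself cannot express "indexable if possible".

```rust
// `T: ?Sized` removes the default `T: Sized` bound, so `T` may be an unsized type.
// `byte_len` is a hypothetical helper used only to illustrate the syntax.
fn byte_len<T: ?Sized + AsRef<[u8]>>(value: &T) -> usize {
    value.as_ref().len()
}

fn main() {
    // Here `T = str`, which is unsized; this only compiles because of `?Sized`.
    let s: &str = "hello";
    println!("{}", byte_len(s));

    // Here `T = [u8]`, another unsized type.
    let bytes: &[u8] = &[1, 2, 3];
    println!("{}", byte_len(bytes));
}
```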

Upvotes: 4

Views: 720

Answers (2)

MB-F

Reputation: 23647

You can change the definition of the trait so that U is a parameter of the trait rather than of the method:

trait Distribution<U> {
    type T;
    fn sample(&self, x: U) -> Self::T;
}

This lets each implementation put its own bounds on U:

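impl<U> Distribution<U> for ConditionalNormal {
    type T = f64;

    fn sample(&self, _x: U) -> Self::T {
        0.0
    }
}

impl<U> Distribution<U> for MultivariateConditionalNormal
where
    U: std::ops::Index<usize, Output = f64>,
{
    type T = f64;

    fn sample(&self, x: U) -> Self::T {
        0.0 + x[0]
    }
}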
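A self-contained sketch of this approach, assuming both structs return `f64` as in the question. The `main` function is illustrative only: it samples `ConditionalNormal` with a scalar and `MultivariateConditionalNormal` with a `Vec<f64>`, which implements `Index<usize, Output = f64>`.

```rust
use std::ops::Index;

// The input type `U` is a parameter of the trait, not of the method,
// so each impl can choose its own bounds on `U`.
trait Distribution<U> {
    type T;
    fn sample(&self, x: U) -> Self::T;
}

struct ConditionalNormal;
struct MultivariateConditionalNormal;

// No bounds on `U`: any input is accepted, including a plain `f64`.
impl<U> Distribution<U> for ConditionalNormal {
    type T = f64;

    fn sample(&self, _x: U) -> Self::T {
        0.0
    }
}

// Only inputs that can be indexed by `usize` and yield `f64` are accepted.
impl<U> Distribution<U> for MultivariateConditionalNormal
where
    U: Index<usize, Output = f64>,
{
    type T = f64;

    fn sample(&self, x: U) -> Self::T {
        0.0 + x[0]
    }
}

fn main() {
    let scalar = ConditionalNormal.sample(1.5_f64);
    let vector = MultivariateConditionalNormal.sample(vec![2.0_f64, 3.0]);
    println!("{} {}", scalar, vector);
}
```

Because the bound lives on each impl, a call such as `MultivariateConditionalNormal.sample(1.5_f64)` is rejected at compile time, since `f64` does not implement `Index<usize>`.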

Upvotes: 5

Boiethios

Reputation: 42889

You can add a new trait that specifies the capabilities U must have:

trait Distribution {
    type T;
    fn sample<U>(&self, x: U) -> Self::T
    where
        U: Samplable;
}

struct ConditionalNormal;

impl Distribution for ConditionalNormal {
    type T = f64;

    fn sample<U>(&self, x: U) -> Self::T
    where
        U: Samplable,
    {
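        // Read the scalar through `Samplable`; calling `.value()` on the bare
        // literal `0.0` would fail with E0689 (ambiguous numeric type).
        x.value()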
    }
}

struct MultivariateConditionalNormal;

impl Distribution for MultivariateConditionalNormal {
    type T = f64;

    fn sample<U>(&self, x: U) -> Self::T
    where
        U: Samplable,
    {
        0.0 + x.value()
    }
}

trait Samplable {
    fn value(&self) -> f64;
}

impl Samplable for f64 {
    fn value(&self) -> f64 {
        *self
    }
}

impl Samplable for Vec<f64> {
    fn value(&self) -> f64 {
        self[0]
    }
}

fn main() {}
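A self-contained sketch showing how this design is used and extended. The array impl for `[f64; N]` is a hypothetical addition (it relies on const generics, available since Rust 1.51), and here `ConditionalNormal` is assumed to return its scalar input unchanged.

```rust
// What a distribution needs from its input: a single `f64` value.
trait Samplable {
    fn value(&self) -> f64;
}

impl Samplable for f64 {
    fn value(&self) -> f64 {
        *self
    }
}

impl Samplable for Vec<f64> {
    fn value(&self) -> f64 {
        self[0] // panics if the vector is empty
    }
}

// Hypothetical extra impl: fixed-size arrays of any length `N`.
impl<const N: usize> Samplable for [f64; N] {
    fn value(&self) -> f64 {
        self[0] // panics if `N == 0`
    }
}

trait Distribution {
    type T;
    fn sample<U: Samplable>(&self, x: U) -> Self::T;
}

struct ConditionalNormal;
struct MultivariateConditionalNormal;

impl Distribution for ConditionalNormal {
    type T = f64;

    // Assumption for this sketch: return the scalar input as-is.
    fn sample<U: Samplable>(&self, x: U) -> Self::T {
        x.value()
    }
}

impl Distribution for MultivariateConditionalNormal {
    type T = f64;

    fn sample<U: Samplable>(&self, x: U) -> Self::T {
        0.0 + x.value()
    }
}

fn main() {
    let a = ConditionalNormal.sample(1.5_f64);
    let b = MultivariateConditionalNormal.sample(vec![2.0_f64, 3.0]);
    let c = MultivariateConditionalNormal.sample([4.0_f64, 5.0]);
    println!("{} {} {}", a, b, c);
}
```

New input types only need a `Samplable` impl; `Distribution` stays unchanged. The trade-off is that every distribution accepts every `Samplable` type, so passing a `Vec<f64>` to `ConditionalNormal` also compiles, unlike a design where each impl chooses its own bound on U.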

Upvotes: 3