nunam

Reputation: 47

rust implement trait for multiple types

I am trying to implement a trait for multiple types at once. The only way I found to avoid code duplication is to convert all the types to one common struct and implement the trait for that struct, as below.

trait Increment {
    fn increment(&self) -> Option<String>;
}

struct NumberWrapper {
    number: String,
}

impl Increment for NumberWrapper {
    fn increment(&self) -> Option<String> {
        let num: Result<u64, _> = self.number.parse();
        match num {
            Err(_) => None,
            Ok(x) => Some((x + 1).to_string())
        }
    }
}

impl<T> From<T> for NumberWrapper where T: ToString {
    fn from(input: T) -> NumberWrapper {
        NumberWrapper { number: input.to_string() }
    }
}

fn main() {
    let number_u8: u8 = 10;
    println!("number_u8 is: {}", NumberWrapper::from(number_u8).increment().unwrap());
    let number_u16: u16 = 10;
    println!("number_u16 is: {}", NumberWrapper::from(number_u16).increment().unwrap());
    let number_u32: u32 = 10;
    println!("number_u32 is: {}", NumberWrapper::from(number_u32).increment().unwrap());
    let number_u64: u64 = 10;
    println!("number_u64 is: {}", NumberWrapper::from(number_u64).increment().unwrap());
}

Is there any other way to do the same?

Upvotes: 0

Views: 1593

Answers (2)

isaactfa

Reputation: 6651

Two ways to do this a little more elegantly come to mind. First, I'm guessing you'd rather have your trait look something like this:

trait Increment {
    // It would probably be better to take `self` by value if you
    // just want this for numeric types which are cheaply copied,
    // but I'll leave it for generality.
    fn increment(&self) -> Option<Self> where Self: Sized;
}

I will assume this going forward (but please correct me).

The first way uses a pretty simple macro:

macro_rules! impl_increment {
    ($($t:ty),*) => {
        $(
            impl Increment for $t {
                fn increment(&self) -> Option<Self> {
                    self.checked_add(1)
                }
            }
        )*
    }
}

It has a single rule, which matches any number of comma-separated types and implements Increment for each of them using the checked_add method that Rust's primitive integer types provide. You can invoke it like this:

// This will create an impl block for each of these types:
impl_increment!{u8, u16, u32, u64, i8, i16, i32, i64}

fn main() {
    let x = 41u32;
    assert_eq!(x.increment(), Some(42));
    let y = -60_000i64;
    assert_eq!(y.increment(), Some(-59_999));
    let z = 255u8;
    assert_eq!(z.increment(), None);
}
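The comment in the trait above suggests taking `self` by value for cheaply copied numeric types. A minimal sketch of that variant follows; the names `IncrementByValue` and `impl_increment_by_value` are made up for this illustration, and the `$(,)?` fragment is an optional addition that lets the type list end with a trailing comma.

```rust
// Hypothetical by-value variant of the trait (name invented for this sketch).
trait IncrementByValue: Sized {
    fn increment(self) -> Option<Self>;
}

macro_rules! impl_increment_by_value {
    // Any number of comma-separated types; `$(,)?` allows a trailing comma.
    ($($t:ty),* $(,)?) => {
        $(
            impl IncrementByValue for $t {
                fn increment(self) -> Option<Self> {
                    // `checked_add` returns None on overflow instead of panicking.
                    self.checked_add(1)
                }
            }
        )*
    };
}

// One impl block is generated per listed type.
impl_increment_by_value! { u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, }

fn main() {
    assert_eq!(41u32.increment(), Some(42));
    assert_eq!(u8::MAX.increment(), None);
    assert_eq!((-1isize).increment(), Some(0));
}
```

Since all of these types are `Copy`, the caller can still use the original value after calling `increment`.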

Or you can do it similarly to what you were already doing, by converting to and from a common type. In this case that means using the Into<u64> and TryFrom<u64> traits, which the unsigned integer types up to and including u64 implement:

use std::convert::TryFrom;

impl<T> Increment for T
where T: Copy + Into<u64> + TryFrom<u64>
{
    fn increment(&self) -> Option<Self> {
        let padded: u64 = (*self).into();
        // checked_add avoids an overflow panic when padded is u64::MAX.
        padded.checked_add(1).and_then(|n| T::try_from(n).ok())
    }
}

fn main() {
    let x = 41u32;
    assert_eq!(x.increment(), Some(42));
    let y = 60_000u64;
    assert_eq!(y.increment(), Some(60_001));
    let z = 255u8;
    assert_eq!(z.increment(), None);
}

This does extra conversion work at runtime and doesn't generalize as nicely (it won't work for signed integer types, for example). So I'd go with the macro route.
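If signed types matter, one possible workaround is to widen through `i128` instead of `u64`. This is only a sketch and rests on an assumption: it covers just the types that implement both `Into<i128>` and `TryFrom<i128>`, which includes the 8 to 64 bit signed and unsigned integers and `i128` itself, but not `u128`, `usize` or `isize`.

```rust
use std::convert::TryFrom;

trait Increment {
    fn increment(&self) -> Option<Self> where Self: Sized;
}

// Blanket impl: widen to i128, add one, then try to narrow back.
impl<T> Increment for T
where
    T: Copy + Into<i128> + TryFrom<i128>,
{
    fn increment(&self) -> Option<Self> {
        let widened: i128 = (*self).into();
        // checked_add guards the i128::MAX case; try_from fails when the
        // result no longer fits in T, which becomes None via ok().
        widened.checked_add(1).and_then(|n| T::try_from(n).ok())
    }
}

fn main() {
    assert_eq!(41u32.increment(), Some(42));
    assert_eq!((-60_000i64).increment(), Some(-59_999));
    assert_eq!(127i8.increment(), None);
    assert_eq!(255u8.increment(), None);
}
```

A blanket impl like this cannot coexist with the macro-generated impls for the same trait, so it is an alternative to the macro rather than a complement.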

Upvotes: 0

frankplow

Reputation: 512

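Blanket implementations can be used to implement a trait for all types which satisfy some other trait(s). I'm not sure exactly what the trait in your example is meant to describe, but I hope the following example illustrates the idea. It uses the One trait from the external num crate. Note that the plain + panics on overflow in debug builds, so this version never returns None.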

use std::ops::Add;
use num::traits::One;

trait Increment {
    fn increment(&self) -> Option<String>;
}

impl<T> Increment for T
    where T: Add + Copy + One,
          <T as Add>::Output: ToString,
{
    fn increment(&self) -> Option<String> {
        Some((*self + One::one()).to_string())
    }
}

fn main() {
    let number_u8: u8 = 10;
    println!("number_u8 is: {}", number_u8.increment().unwrap());
    let number_u16: u16 = 10;
    println!("number_u16 is: {}", number_u16.increment().unwrap());
    let number_u32: u32 = 10;
    println!("number_u32 is: {}", number_u32.increment().unwrap());
    let number_u64: u64 = 10;
    println!("number_u64 is: {}", number_u64.increment().unwrap());
}

Upvotes: 3
