ideasman42
ideasman42

Reputation: 48248

How to write generic functions that take types which are themselves generic?

I'm looking to write a function that accepts different types which differ primarily in the mutability (const / mut) of a member, and which are themselves generic over a type.

To simplify the question, I'm looking to write a function which takes either a constant or a mutable struct, e.g.:

/// Holds a read-only raw pointer to a `T`.
pub struct PtrConst<T> {
    ptr: *const T,
}

/// Holds a writable raw pointer to a `T`.
pub struct PtrMut<T> {
    ptr: *mut T,
}

How could I write a function that takes either a PtrConst<SomeType> or PtrMut<SomeType>?


This snippet is rather long, but I've attempted to simplify it.

Playground link.

// ---------------------------------------------------------------------------
// Test Case: This isn't working!

// How to make a generic function?
// See below for 'PtrConst' & 'PtrMut'.

/// Construction and null-checking meant to be shared by `PtrConst` and `PtrMut`.
///
/// NOTE(review): `new` is declared here without parameters, but the impls in
/// this listing define `new(ptr: *mut T)` and `new(ptr: *const T)`. Those do
/// not match this signature, which is one reason the snippet fails to compile.
/// The trait also has no way to name the pointer type, so a single `new`
/// signature cannot cover both `*mut T` and `*const T`.
pub trait PtrAnyFuncs {
    fn new() -> Self;
    fn is_null(&self) -> bool;
}

/// Bundle of the traits that both pointer wrappers implement.
///
/// NOTE(review): as written, this trait takes no type parameter, and its
/// `Deref` bound does not fix `Target`. As a result, `PtrAny<MyStruct>` (used
/// by `generic_test`) is rejected, and a generic implementor cannot be
/// dereferenced to a known type.
pub trait PtrAny:
    Deref +
    Copy +
    Clone +
    PartialEq +
    PtrAnyFuncs +
    {}

// Blanket impl: every type that meets all of the bounds below gets `PtrAny`.
// NOTE(review): `Deref` is unconstrained here as well, so nothing ties the
// implementor to a particular pointee type.
impl<TPtr> PtrAny for TPtr where TPtr:
    Deref +
    Copy +
    Clone +
    PartialEq +
    PtrAnyFuncs +
    {}


// Intended to print the fields of the `MyStruct` behind either pointer wrapper,
// or "Test: NULL" when the wrapped pointer is null.
// NOTE(review): does not compile as posted. `PtrAny` in this listing takes no
// type parameter, so `PtrAny<MyStruct>` is rejected. Its `Deref` bound also
// leaves `Target` open, so `a.my_val_a` cannot be resolved through `Deref`.
fn generic_test<T: PtrAny<MyStruct>>(a: T) {
    if a.is_null() {
        println!("Test: NULL");
    } else {
        println!("Test: {} {}", a.my_val_a, a.my_val_b);
    }
}


// ---------------------------------------------------------------------------
// Attempt to use generic function

/// Sample payload that `main` wraps in the pointer types and `generic_test` reads.
struct MyStruct {
    /// First sample value (printed first).
    pub my_val_a: usize,
    /// Second sample value (printed second).
    pub my_val_b: usize,
}

// Builds one mutable and one const wrapper and passes both to `generic_test`.
fn main() {
    let mut a: MyStruct = MyStruct { my_val_a: 10, my_val_b: 2, };
    let b: MyStruct = MyStruct { my_val_a: 4, my_val_b: 4, };

    // NOTE(review): `new` is called with a pointer argument here, but the
    // `PtrAnyFuncs` trait in this listing declares `fn new() -> Self`.
    let a_ptr = PtrMut::new(&mut a as *mut MyStruct);
    // NOTE(review): this wraps `a` a second time, and `b` is otherwise unused,
    // so `&b` was presumably intended.
    let b_ptr = PtrConst::new(&a as *const MyStruct);

    generic_test(a_ptr);
    generic_test(b_ptr);
}


// ---------------------------------------------------------------------------
// PtrMut

use std::ops::{
    Deref,
    DerefMut,
};

/// Thin wrapper around a `*mut T`, readable and writable through `Deref`/`DerefMut`.
#[derive(Hash)]
#[repr(C)]
pub struct PtrMut<T> {
    /// The wrapped pointer; `PtrMut::null()` shows it may be null.
    ptr: *mut T,
}

// `PtrAnyFuncs` for the mutable wrapper.
// NOTE(review): `new` here takes `ptr: *mut T`. That does not match the trait's
// parameterless `fn new() -> Self`, so this impl is rejected as posted.
impl<T> PtrAnyFuncs for PtrMut<T> {
    // Wraps `ptr` as-is (`ptr` is already `*mut T`, so the cast is a no-op).
    #[inline(always)]
    fn new(ptr: *mut T) -> PtrMut<T> {
        PtrMut { ptr: ptr as *mut T }
    }

    // True when the wrapped pointer equals the null pointer.
    #[inline(always)]
    fn is_null(&self) -> bool {
        self.ptr == ::std::ptr::null_mut()
    }
}

// Inherent helpers for the mutable wrapper.
impl<T> PtrMut<T> {
    // Returns a wrapper holding a null pointer.
    #[inline(always)]
    pub fn null() -> PtrMut<T> {
        PtrMut { ptr: ::std::ptr::null_mut() }
    }

    // Returns the wrapped raw pointer unchanged.
    #[inline(always)]
    pub fn as_pointer(&self) -> *mut T {
        self.ptr
    }

    // only for 'PtrMut'
    // Returns a `PtrConst` to the same address.
    // NOTE(review): this calls `PtrConst::new` with a pointer argument, which
    // the parameterless `PtrAnyFuncs::new` in this listing does not accept.
    #[inline(always)]
    pub fn as_const(&self) -> PtrConst<T> {
        PtrConst::new(self.ptr as *const T)
    }
}

// Implemented by hand: `#[derive(Clone, Copy)]` would add `T: Clone` / `T: Copy`
// bounds, whereas these impls make the wrapper copyable for any `T`.
impl<T> Clone for PtrMut<T> {
    #[inline(always)]
    fn clone(&self) -> Self {
        *self
    }
}
impl<T> Copy for PtrMut<T> {}

impl<T> Deref for PtrMut<T> {
    type Target = T;

    /// Borrows the pointee.
    ///
    /// The wrapper tracks neither lifetime nor validity. The caller must
    /// ensure that the pointer is non-null and aligned, and that it points at
    /// a live `T` for as long as the returned reference is used. Null is
    /// caught by a debug assertion; the other conditions are not checked.
    #[inline(always)]
    fn deref(&self) -> &T {
        debug_assert!(!self.ptr.is_null(), "dereferenced a null PtrMut");
        // SAFETY: non-null is asserted above in debug builds; validity,
        // alignment and liveness of the pointee are the caller's contract.
        unsafe { &*self.ptr }
    }
}

impl<T> DerefMut for PtrMut<T> {
    /// Mutably borrows the pointee.
    ///
    /// The requirements of `deref` apply. In addition, no other reference to
    /// the pointee may be live while the returned `&mut T` is in use.
    /// `PtrMut` is `Copy`, so separate copies can hand out aliasing `&mut T`;
    /// avoiding that is the caller's job. Null is caught by a debug assertion.
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut T {
        debug_assert!(!self.ptr.is_null(), "dereferenced a null PtrMut");
        // SAFETY: non-null is asserted above in debug builds; validity and
        // exclusive access to the pointee are the caller's contract.
        unsafe { &mut *self.ptr }
    }
}

impl<T> PartialEq for PtrMut<T> {
    fn eq(&self, other: &PtrMut<T>) -> bool {
        self.ptr == other.ptr
    }
}

// ---------------------------------------------------------------------------
// PtrConst

/// Thin wrapper around a `*const T`; readable through `Deref`, never writable.
#[derive(Hash)]
#[repr(C)]
pub struct PtrConst<T> {
    /// The wrapped pointer; `PtrConst::null()` shows it may be null.
    ptr: *const T,
}

// `PtrAnyFuncs` for the const wrapper.
// NOTE(review): `new` here takes `ptr: *const T`. That does not match the
// trait's parameterless `fn new() -> Self`, so this impl is rejected as posted.
impl<T> PtrAnyFuncs for PtrConst<T> {
    // Wraps `ptr` as-is (`ptr` is already `*const T`, so the cast is a no-op).
    #[inline(always)]
    fn new(ptr: *const T) -> PtrConst<T> {
        PtrConst { ptr: ptr as *const T }
    }

    // True when the wrapped pointer is null. It compares a `*const T`
    // against `null_mut()`; `self.ptr.is_null()` would be the direct form.
    #[inline(always)]
    fn is_null(&self) -> bool {
        self.ptr == ::std::ptr::null_mut()
    }
}

impl<T> PtrConst<T> {

    #[inline(always)]
    pub fn null() -> PtrConst<T> {
        PtrConst { ptr: ::std::ptr::null_mut() }
    }

    #[inline(always)]
    pub fn as_pointer(&self) -> *const T {
        self.ptr
    }
}

// Implemented by hand: `#[derive(Clone, Copy)]` would add `T: Clone` / `T: Copy`
// bounds, whereas these impls make the wrapper copyable for any `T`.
impl<T> Clone for PtrConst<T> {
    #[inline(always)]
    fn clone(&self) -> Self {
        *self
    }
}
impl<T> Copy for PtrConst<T> {}

impl<T> Deref for PtrConst<T> {
    type Target = T;

    /// Borrows the pointee.
    ///
    /// The wrapper tracks neither lifetime nor validity. The caller must
    /// ensure that the pointer is non-null and aligned, and that it points at
    /// a live `T` for as long as the returned reference is used. Null is
    /// caught by a debug assertion; the other conditions are not checked.
    #[inline(always)]
    fn deref(&self) -> &T {
        debug_assert!(!self.ptr.is_null(), "dereferenced a null PtrConst");
        // SAFETY: non-null is asserted above in debug builds; validity,
        // alignment and liveness of the pointee are the caller's contract.
        unsafe { &*self.ptr }
    }
}

// no DerefMut for PtrConst, only PtrMut
impl<T> PartialEq for PtrConst<T> {
    fn eq(&self, other: &PtrConst<T>) -> bool {
        self.ptr == other.ptr
    }
}

Upvotes: 0

Views: 719

Answers (2)

ideasman42
ideasman42

Reputation: 48248

Thanks to help from @futile & @oli_obk_ on IRC, here is a working example of the code in the question.

  • `PtrAny` and `PtrAnyFuncs` needed to take a type parameter (the pointee type).
  • `PtrAnyFuncs` needed to use an associated type, so that the argument to `new` could be made generic across `*mut` and `*const`.
  • `Deref` needed to declare the type it dereferences to: `Deref<Target=T>`.

Working code:

/// Construction and null-checking shared by the pointer wrappers.
///
/// `T` is not used by the items below. It names the pointee so that
/// `PtrAny<T>` can require `PtrAnyFuncs<T>`.
pub trait PtrAnyFuncs<T> {
    /// Raw pointer type that [`new`](PtrAnyFuncs::new) accepts; each implementor picks it.
    type InnerPtr;

    /// Returns `true` when the wrapped pointer is null.
    fn is_null(&self) -> bool;

    /// Wraps `ptr` in `Self`.
    fn new(ptr: Self::InnerPtr) -> Self;
}

/// Umbrella trait for pointer wrappers whose pointee is `T`.
///
/// A single `P: PtrAny<T>` bound lets generic code dereference to `T`, copy,
/// compare, and call the `PtrAnyFuncs<T>` methods.
pub trait PtrAny<T>
where
    Self: Deref<Target = T> + Copy + Clone + PartialEq + PtrAnyFuncs<T>,
{
}

// Every type satisfying the full bound set is automatically a `PtrAny<T>`,
// so neither wrapper needs a hand-written impl.
impl<P, T> PtrAny<T> for P
where
    P: Deref<Target = T> + Copy + Clone + PartialEq + PtrAnyFuncs<T>,
{
}

/// Prints the two fields of the `MyStruct` behind `a`, or `Test: NULL` when
/// `a` wraps a null pointer. It accepts any `PtrAny<MyStruct>`, so both
/// `PtrMut<MyStruct>` and `PtrConst<MyStruct>` work.
fn generic_test<P>(a: P)
where
    P: PtrAny<MyStruct>,
{
    if a.is_null() {
        println!("Test: NULL");
        return;
    }
    println!("Test: {} {}", a.my_val_a, a.my_val_b);
}


// ---------------------------------------------------------------------------
// Attempt to use generic function

/// Sample payload that `main` wraps in the pointer types and `generic_test` reads.
struct MyStruct {
    /// First sample value (printed first).
    pub my_val_a: usize,
    /// Second sample value (printed second).
    pub my_val_b: usize,
}

fn main() {
    let mut first = MyStruct { my_val_a: 10, my_val_b: 2 };
    let second = MyStruct { my_val_a: 4, my_val_b: 4 };

    // One writable and one read-only wrapper, both handed to the same
    // generic function.
    let mutable_ptr = PtrMut::new(&mut first as *mut MyStruct);
    let const_ptr = PtrConst::new(&second as *const MyStruct);

    generic_test(mutable_ptr);
    generic_test(const_ptr);
}


// ---------------------------------------------------------------------------
// PtrMut

use std::ops::{
    Deref,
    DerefMut,
};

/// Thin wrapper around a `*mut T`, readable and writable through `Deref`/`DerefMut`.
#[derive(Hash)]
#[repr(C)]
pub struct PtrMut<T> {
    /// The wrapped pointer; `PtrMut::null()` shows it may be null.
    ptr: *mut T,
}

impl<T> PtrAnyFuncs<T> for PtrMut<T> {
    /// The mutable wrapper is built from a `*mut T`. Accepting `*const T` here
    /// would let callers silently turn a read-only pointer into one that
    /// `DerefMut` writes through.
    type InnerPtr = *mut T;

    /// Wraps `ptr` without checking it; null is accepted, and `is_null`
    /// reports it.
    #[inline(always)]
    fn new(ptr: Self::InnerPtr) -> PtrMut<T> {
        PtrMut { ptr }
    }

    /// Returns `true` when the wrapped pointer is null.
    #[inline(always)]
    fn is_null(&self) -> bool {
        self.ptr.is_null()
    }
}

impl<T> PtrMut<T> {
    /// Creates a wrapper around a null pointer.
    #[inline(always)]
    pub fn null() -> PtrMut<T> {
        Self { ptr: ::std::ptr::null_mut() }
    }

    /// Hands back the wrapped raw pointer unchanged.
    #[inline(always)]
    pub fn as_pointer(&self) -> *mut T {
        self.ptr
    }

    /// Produces a read-only `PtrConst` to the same address.
    ///
    /// Only `PtrMut` has this method.
    #[inline(always)]
    pub fn as_const(&self) -> PtrConst<T> {
        let read_only: *const T = self.ptr;
        PtrConst::new(read_only)
    }
}

// Implemented by hand: `#[derive(Clone, Copy)]` would add `T: Clone` / `T: Copy`
// bounds, whereas these impls make the wrapper copyable for any `T`.
impl<T> Clone for PtrMut<T> {
    #[inline(always)]
    fn clone(&self) -> Self {
        *self
    }
}
impl<T> Copy for PtrMut<T> {}

impl<T> Deref for PtrMut<T> {
    type Target = T;

    /// Borrows the pointee.
    ///
    /// The wrapper tracks neither lifetime nor validity. The caller must
    /// ensure that the pointer is non-null and aligned, and that it points at
    /// a live `T` for as long as the returned reference is used. Null is
    /// caught by a debug assertion; the other conditions are not checked.
    #[inline(always)]
    fn deref(&self) -> &T {
        debug_assert!(!self.ptr.is_null(), "dereferenced a null PtrMut");
        // SAFETY: non-null is asserted above in debug builds; validity,
        // alignment and liveness of the pointee are the caller's contract.
        unsafe { &*self.ptr }
    }
}

impl<T> DerefMut for PtrMut<T> {
    /// Mutably borrows the pointee.
    ///
    /// The requirements of `deref` apply. In addition, no other reference to
    /// the pointee may be live while the returned `&mut T` is in use.
    /// `PtrMut` is `Copy`, so separate copies can hand out aliasing `&mut T`;
    /// avoiding that is the caller's job. Null is caught by a debug assertion.
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut T {
        debug_assert!(!self.ptr.is_null(), "dereferenced a null PtrMut");
        // SAFETY: non-null is asserted above in debug builds; validity and
        // exclusive access to the pointee are the caller's contract.
        unsafe { &mut *self.ptr }
    }
}

impl<T> PartialEq for PtrMut<T> {
    fn eq(&self, other: &PtrMut<T>) -> bool {
        self.ptr == other.ptr
    }
}

// ---------------------------------------------------------------------------
// PtrConst

/// Thin wrapper around a `*const T`; readable through `Deref`, never writable.
#[derive(Hash)]
#[repr(C)]
pub struct PtrConst<T> {
    /// The wrapped pointer; `PtrConst::null()` shows it may be null.
    ptr: *const T,
}

impl<T> PtrAnyFuncs<T> for PtrConst<T> {
    /// The const wrapper is built from a `*const T`.
    type InnerPtr = *const T;

    /// Wraps `ptr` without checking it; null is accepted, and `is_null`
    /// reports it.
    #[inline(always)]
    fn new(ptr: Self::InnerPtr) -> PtrConst<T> {
        // `ptr` is already `*const T`; no cast needed.
        PtrConst { ptr }
    }

    /// Returns `true` when the wrapped pointer is null.
    #[inline(always)]
    fn is_null(&self) -> bool {
        // `is_null` instead of comparing a `*const T` against `null_mut()`
        // (clippy `cmp_null`).
        self.ptr.is_null()
    }
}

impl<T> PtrConst<T> {
    #[inline(always)]
    pub fn null() -> PtrConst<T> {
        PtrConst { ptr: ::std::ptr::null_mut() }
    }

    #[inline(always)]
    pub fn as_pointer(&self) -> *const T {
        self.ptr
    }
}

// Implemented by hand: `#[derive(Clone, Copy)]` would add `T: Clone` / `T: Copy`
// bounds, whereas these impls make the wrapper copyable for any `T`.
impl<T> Clone for PtrConst<T> {
    #[inline(always)]
    fn clone(&self) -> Self {
        *self
    }
}
impl<T> Copy for PtrConst<T> {}

impl<T> Deref for PtrConst<T> {
    type Target = T;

    /// Borrows the pointee.
    ///
    /// The wrapper tracks neither lifetime nor validity. The caller must
    /// ensure that the pointer is non-null and aligned, and that it points at
    /// a live `T` for as long as the returned reference is used. Null is
    /// caught by a debug assertion; the other conditions are not checked.
    #[inline(always)]
    fn deref(&self) -> &T {
        debug_assert!(!self.ptr.is_null(), "dereferenced a null PtrConst");
        // SAFETY: non-null is asserted above in debug builds; validity,
        // alignment and liveness of the pointee are the caller's contract.
        unsafe { &*self.ptr }
    }
}

// no DerefMut for PtrConst, only PtrMut
impl<T> PartialEq for PtrConst<T> {
    fn eq(&self, other: &PtrConst<T>) -> bool {
        self.ptr == other.ptr
    }
}

Upvotes: 0

oli_obk
oli_obk

Reputation: 31283

The solution is to make your trait generic over the pointee type:

pub trait PtrAny<T>: ...

impl<T, TPtr> PtrAny<T> for TPtr where TPtr: ...

Note that this doesn't fix your linked code example, because Rust can't abstract over (non-)mutability.

Upvotes: 2

Related Questions