Kyle

Reputation: 4577

How to accomplish covariant return types when returning a shared_ptr?

#include <boost/shared_ptr.hpp>
using namespace boost;

class A {};
class B : public A {};

class X {
  virtual shared_ptr<A> foo();
};

class Y : public X {
  virtual shared_ptr<B> foo(); // error: shared_ptr<B> is not covariant with shared_ptr<A>
};

The return types aren't covariant (nor are they, therefore, legal), but they would be if I were using raw pointers instead. What's the commonly accepted idiom to work around this, if there is one?

Upvotes: 19

Views: 2778

Answers (4)

Golden Rockefeller

Reputation: 313

Here is my take on this question:

You can get closer to truly covariant smart-pointer return types with a templated implementation class that converts the pointer automatically. This approach also works with custom deleters, and it still works when the derived class does not hold sole ownership of the pointer it returns. The code would look like this:

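#include <memory>

class Base {
    virtual void print_impl(std::shared_ptr<Base>& base_ptr) = 0;
public:
    virtual ~Base() = default;
    // Non-virtual entry point: forwards to the private virtual hook.
    std::shared_ptr<Base> print() {
        std::shared_ptr<Base> base;
        print_impl(base);
        return base;
    }
};

template <typename TBase>
class BaseImpl : public Base {
    // Calls the typed print() and converts shared_ptr<TBase> to shared_ptr<Base>.
    void print_impl(std::shared_ptr<Base>& base_ptr) override {
        auto tbase = print();
        base_ptr = tbase;
    }
public:
    // Hides Base::print(); derived classes override this typed version.
    virtual std::shared_ptr<TBase> print() = 0;
};

class Derived : public BaseImpl<Derived> {
public:
    std::shared_ptr<Derived> print() override;
};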

In the above example, all derived classes only have to implement std::shared_ptr<Derived> print(); and inherit from BaseImpl<Derived>, where Derived can be any type whose std::shared_ptr converts implicitly to std::shared_ptr<Base>.

For an executable example:

#include <iostream>
#include <memory>
#include <typeinfo>

class Base {
public:
    virtual ~Base() {std::cout << "Destroyed\n";}
    virtual void print() { std::cout << "Base\n"; }
};

class Derived : public Base {
public:
    void print() override { std::cout << "Derived\n"; }
};

class IFactory {
public:
    virtual ~IFactory() = default;
    std::shared_ptr<Base> create() {
        std::shared_ptr<Base> base;
        create_impl(base);
        return base;
    }
private:
    virtual void create_impl(std::shared_ptr<Base>& base_ptr) = 0;
};

template<typename TBase>
class UFactory : public IFactory{
public:
    virtual std::shared_ptr<TBase> create() = 0;
private:
    void create_impl(std::shared_ptr<Base>& base_ptr) override {
        auto tbase = create();  // calls the most-derived create()
        base_ptr = tbase;       // shared_ptr<TBase> -> shared_ptr<Base>
    }
};
class Factory : public UFactory<Base>{
public:
    std::shared_ptr<Base> create() override {
        return std::make_shared<Base>();
    }
};
class DerivedFactory : public UFactory<Derived> {
public:
    std::shared_ptr<Derived> create() override {
        return std::make_shared<Derived>();
    }
};
int main() {
    {
        std::shared_ptr<IFactory> factory = std::make_shared<DerivedFactory>();
        std::shared_ptr<Base> base = factory->create();
        std::cout << typeid(factory->create().get()).name() << std::endl; // static type Base*; name() is implementation-defined (e.g. "P4Base" with GCC)
        base->print(); // Output: Derived
    }
    std::cout << "-----------" << std::endl;
    {
        std::shared_ptr<DerivedFactory> factory = std::make_shared<DerivedFactory>();
        std::shared_ptr<Base> base = factory->create();
        std::cout << typeid(factory->create().get()).name() << std::endl; // static type Derived*; name() is implementation-defined (e.g. "P7Derived" with GCC)
        base->print(); // Output: Derived
    }
    return 0;
}

Upvotes: 0

quant_dev

Reputation: 6231

I just return a bare pointer and wrap it immediately in the shared pointer.

Upvotes: -1

sdkljhdf hda

Reputation: 1407

Not directly, but you can fake it by making the actual virtual function inaccessible from outside the class and wrapping the virtual call in a non-virtual function. The downside is that you have to remember to implement this wrapper function in each derived class. But you could get around this by putting both the virtual function declaration and the wrapper into a macro.

#include <cassert>
#include <boost/shared_ptr.hpp>
#include <boost/make_shared.hpp>

using namespace boost; // for shared_ptr, make_shared and static_pointer_cast.

// "Fake" implementation of the clone() function.
#define CLONE(MyType) \
    shared_ptr<MyType> clone() \
    { \
        shared_ptr<Base> res = clone_impl(); \
        assert(dynamic_cast<MyType*>(res.get()) != 0); \
        return static_pointer_cast<MyType>(res); \
    }

class Base 
{
protected:
    // The actual implementation of the clone() function. 
    virtual shared_ptr<Base> clone_impl() { return make_shared<Base>(*this); }

public:
    // non-virtual shared_ptr<Base> clone();
    CLONE(Base)
};

class Derived : public Base
{
protected:
    virtual shared_ptr<Base> clone_impl() { return make_shared<Derived>(*this); }

public:
    // non-virtual shared_ptr<Derived> clone();
    CLONE(Derived)
};


int main()
{
    shared_ptr<Derived> p = make_shared<Derived>();
    shared_ptr<Derived> clone = p->clone();

    return 0;
}

Upvotes: 4

Potatoswatter

Reputation: 137810

I think that a solution is fundamentally impossible because covariance depends on pointer arithmetic which is incompatible with smart pointers.

When Y::foo returns shared_ptr<B> to a dynamic caller, it must be cast to shared_ptr<A> before use. In your case, a B* can (probably) simply be reinterpreted as an A*, but for multiple inheritance, you would need some magic to tell C++ about static_cast<A*>(shared_ptr<B>::get()).

Upvotes: 11
