Erunehtar

Reputation: 1703

Using the visitor pattern to check for derived class types?

I am using the visitor pattern to deal with many different AST problems, and it works really well. For instance, I use it to check an object's type. This works when looking for the exact type, but it doesn't account for derived classes: if Derived inherits from Base, asking whether a Derived object is a Base fails.

Consider the following C++ code:

#include <iostream>
#include <functional>
#include <memory>

using namespace std;

class Base;
class Derived;

class Visitor {
public:
    virtual void visit(Base& object) = 0;
    virtual void visit(Derived& object) = 0;
};

class EmptyVisitor : public Visitor {
public:
    virtual void visit(Base& object) override {}
    virtual void visit(Derived& object) override {}
};

template <class TYPE> class LogicVisitor : public EmptyVisitor {
public:
    LogicVisitor(function<void(TYPE&)> logic) : EmptyVisitor(), logic(logic) {}
    virtual void visit(TYPE& object) override { logic(object); }
private:
    function<void(TYPE&)> logic;
};

class Base {
public:
    virtual void accept(Visitor* visitor) {
        visitor->visit(*this);
    }
};

class Derived : public Base {
public:
    virtual void accept(Visitor* visitor) override {
        visitor->visit(*this);
    }
};

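template <class TYPE> bool is_type(shared_ptr<Base> base)
{
    // Set to true only if accept() dispatches to the overload LogicVisitor<TYPE> overrides.
    bool is_type = false;
    LogicVisitor<TYPE> logic_visitor([&](TYPE& object) {
        is_type = true;
    });
    base->accept(&logic_visitor);  // implicit upcast to Visitor*; no C-style cast needed
    return is_type;
}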

int main() {
    auto base = make_shared<Base>();
    auto derived = make_shared<Derived>();
    cout << "is_type<Base>(base) = " << (is_type<Base>(base) ? "true" : "false") << endl;
    cout << "is_type<Derived>(base) = " << (is_type<Derived>(base) ? "true" : "false") << endl;
    cout << "is_type<Base>(derived) = " << (is_type<Base>(derived) ? "true" : "false") << endl;
    cout << "is_type<Derived>(derived) = " << (is_type<Derived>(derived) ? "true" : "false") << endl;
    return 0;
}

As expected, it outputs:

is_type<Base>(base) = true
is_type<Derived>(base) = false
is_type<Base>(derived) = false
is_type<Derived>(derived) = true

This is great for retrieving the exact type of an object, but how can I change it so that is_type<Base>(derived) returns true instead of false, so that it effectively checks class inheritance? Is this possible in C++?

Upvotes: 0

Views: 1038

Answers (1)

Rakete1111

Reputation: 48998

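You cannot with the visitor as written. The reason is overload resolution (and your design pattern). Every visitor has two overloads, one for Base& and one for Derived&. LogicVisitor overrides only the overload whose parameter is the template argument, so LogicVisitor<Base> overrides void visit(Base&) and nothing else. Derived::accept, however, calls visit(Derived&), which is still EmptyVisitor's empty version, so your logic never runs.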

You would want LogicVisitor<Base> to override void visit(Derived&) as well (or instead). But that would require the visitor to know every class that derives from Base, and C++ currently has no way to enumerate the derived classes of a type.

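You can use std::is_base_of instead (is_base_of_v requires C++17). Note that this is a compile-time check on U, the static type of the pointer, not on the dynamic type of the object it points to: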

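#include <memory>
#include <type_traits>

// True if T is U or a base class of U (class types); looks only at U, the pointer's static type.
template<typename T, typename U>
constexpr bool is_type(std::shared_ptr<U>) {
    return std::is_base_of_v<std::decay_t<T>, std::decay_t<U>>;
}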

Upvotes: 4
