Beetroot

Overload a function with a derived class argument if you only have a pointer to the base class in C++

I have seen people use containers of pointers to a base class to hold groups of objects that share the same virtual functions. Is it possible to call functions overloaded on the derived classes through these base-class pointers? It is hard to explain what I mean, but (I think) easy to show with code:

class PhysicsObject // An abstract base class
{
    // Members of physics object
    // ...
};

class Circle : public PhysicsObject
{
    // Members of circle
    // ...
};

class Box : public PhysicsObject
{
    // Members of box
    // ...
};

// Overloaded functions (Defined elsewhere)
void ResolveCollision(Circle& a, Box& b);
void ResolveCollision(Circle& a, Circle& b);
void ResolveCollision(Box& a, Box& b);

int main()
{
    // Container to hold all our objects 
    std::vector<PhysicsObject*> objects;

    // Create some circles and boxes and add to objects container
    // ...

    // Resolve any collisions between colliding objects
    for (auto& objA : objects)
        for (auto& objB : objects)
            if (objA != objB)
                ResolveCollision(*objA, *objB); // !!! Error !!! No overload takes (PhysicsObject&, PhysicsObject&)
}

My first idea was to make these functions virtual member functions as well (shown below), but I quickly realised that this has exactly the same issue.

class Circle;
class Box;
class PhysicsObject // Abstract base class
{
public:
    virtual ~PhysicsObject() = default;
    virtual void ResolveCollision(Circle& a) = 0;
    virtual void ResolveCollision(Box& a) = 0;
    // Members of physics object
    // ...
};

class Circle : public PhysicsObject
{
public:
    void ResolveCollision(Circle& a) override;
    void ResolveCollision(Box& a) override;
    // Members of circle
    // ...
};

class Box : public PhysicsObject
{
public:
    void ResolveCollision(Circle& a) override;
    void ResolveCollision(Box& a) override;
    // Members of box
    // ...
};

From googling the problem, it seems it might be solvable with casting, but I can't figure out how to determine the correct type to cast to (and it is ugly). I suspect I am asking the wrong question and that there is a better way to structure my code which sidesteps this problem and achieves the same result.

Answers (3)

Bathsheba

The way I'd do this is to build an Extent class that describes the physical perimeter of an object, perhaps relative to its barycentre. Additionally, you'd have

virtual const Extent& getExtent() const = 0;

in the PhysicsObject class. You then implement getExtent once per object type.

Your collision detection line becomes

ResolveCollision(objA->getExtent(), objB->getExtent());

In a sense this does little more than kick the can down the road, since the complexity moves into the Extent class, but the approach scales well because you only need to write one method per object type.

The alternative, double dispatch, is intrusive insofar as adding a new shape requires changes to every existing shape. Having to recompile the Circle class because you introduced, say, an Ellipse class is a code smell to me.

nh_

I am going to sketch an implementation that does not rely on double dispatch. Instead, it uses a table in which all functions are registered. The table is indexed by the dynamic types of the objects (which are passed as base-class pointers).

First, we have some example shapes. Their types are listed in an enum class. Every shape class defines MY_TYPE as its respective enum entry. Furthermore, each has to implement the base class's pure virtual type method:

enum class ObjectType
{
    Circle,
    Box,
    _Count,
};

class PhysicsObject
{
public:
    virtual ~PhysicsObject() = default;
    virtual ObjectType type() const = 0;
};

class Circle : public PhysicsObject
{
public:
    static const ObjectType MY_TYPE = ObjectType::Circle;

    ObjectType type() const override { return MY_TYPE; }
};

class Box : public PhysicsObject
{
public:
    static const ObjectType MY_TYPE = ObjectType::Box;

    ObjectType type() const override { return MY_TYPE; }
};

Next come the collision resolution functions, which of course have to be implemented for each pair of shapes.

void ResolveCircleCircle(Circle* c1, Circle* c2)
{
    std::cout << "Circle-Circle" << std::endl;
}

void ResolveCircleBox(Circle* c, Box* b)
{
    std::cout << "Circle-Box" << std::endl;
}

void ResolveBoxBox(Box* b1, Box* b2)
{
    std::cout << "Box-Box" << std::endl;
}

Note that there is only Circle-Box here, no Box-Circle, as I assume their collision is detected the same way either way round. More on the Box-Circle case later.

Now to the core part, the function table:

std::function<void(PhysicsObject*,PhysicsObject*)>
    ResolveFunctionTable[(int)(ObjectType::_Count)][(int)(ObjectType::_Count)];
REGISTER_RESOLVE_FUNCTION(Circle, Circle, &ResolveCircleCircle);
REGISTER_RESOLVE_FUNCTION(Circle, Box, &ResolveCircleBox);
REGISTER_RESOLVE_FUNCTION(Box, Box, &ResolveBoxBox);

The table itself is a 2D array of std::function objects. Note that these functions accept pointers to PhysicsObject, not to the derived classes. We then use some macros for easy registration. Of course, the equivalent code could be written by hand, and I am quite aware that macros are typically considered bad practice. However, in my opinion this sort of thing is exactly what macros are good for, and as long as you use meaningful names that do not clutter your global namespace, they are acceptable. Notice again that only Circle-Box is registered, not the other way round.

Now to the fancy macro:

#define CONCAT2(x,y) x##y
#define CONCAT(x,y) CONCAT2(x,y)

// In real code this must appear before the first REGISTER_RESOLVE_FUNCTION use.
// The generated name avoids a leading double underscore (reserved in C++),
// and the trailing semicolon is supplied at the call site.
#define REGISTER_RESOLVE_FUNCTION(o1,o2,fn) \
    const bool CONCAT(resolve_reg_, __LINE__) = []() { \
        int o1type = static_cast<int>(o1::MY_TYPE); \
        int o2type = static_cast<int>(o2::MY_TYPE); \
        assert(o1type <= o2type); \
        assert(!ResolveFunctionTable[o1type][o2type]); \
        ResolveFunctionTable[o1type][o2type] = \
            [](PhysicsObject* p1, PhysicsObject* p2) { \
                    (*fn)(static_cast<o1*>(p1), static_cast<o2*>(p2)); \
            }; \
        return true; \
    }()

The macro defines a uniquely named variable (using the line number), but this variable merely serves to get the code inside the initializing lambda executed during static initialization. The MY_TYPE values (from the ObjectType enum) of the two class arguments, which are concrete classes such as Circle and Box, are used to index the table. The entire mechanism assumes a total order on the types (as defined in the enum): the first assert checks that the pair is registered in that order and tells you if you are doing it wrong (accidentally registering Box-Circle), and the second assert catches registering the same pair twice. Then a lambda is stored in the table for this particular pair of types. It takes two arguments of type PhysicsObject* and casts them to the concrete types before invoking the registered function; the static_cast is safe because an entry is only ever called with objects whose dynamic types match its indices.

Next, let's look at how the table is used. It is now easy to implement a single function that resolves the collision of any two PhysicsObjects:

void ResolveCollision(PhysicsObject* p1, PhysicsObject* p2)
{
    int p1type = static_cast<int>(p1->type());
    int p2type = static_cast<int>(p2->type());
    if(p1type > p2type) {
        std::swap(p1type, p2type);
        std::swap(p1, p2);
    }
    assert(ResolveFunctionTable[p1type][p2type]);
    ResolveFunctionTable[p1type][p2type](p1, p2);
}

It looks up the dynamic types of the arguments and calls the function registered for that pair in ResolveFunctionTable. Notice that the arguments are swapped if they are not in order. Thus you are free to call ResolveCollision with a Box and a Circle, and it will internally invoke the function registered for Circle-Box collision.

Lastly, I will give an example of how to use it:

int main(int argc, char* argv[])
{
    Box box;
    Circle circle;

    ResolveCollision(&box, &box);
    ResolveCollision(&box, &circle);
    ResolveCollision(&circle, &box);
    ResolveCollision(&circle, &circle);
}

Easy, isn't it? See this for a working implementation of the above.


Now, what is the advantage of this approach? The above code is basically all you need to support an arbitrary number of shapes. Let's say you are about to add a Triangle. All you have to do is:

  1. Add an entry Triangle to the ObjectType enum.
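  2. Implement your ResolveTriangleXXX functions (you would have to write these under any approach).
  3. Register them in your table using the macro, e.g. REGISTER_RESOLVE_FUNCTION(Triangle, Triangle, &ResolveTriangleTriangle); mixed pairs must be registered in enum order, e.g. (Circle, Triangle), or the ordering assert fires.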

That's it. No need to add further methods to PhysicsObject, no need to implement methods in all existing types.

I am aware of some 'flaws' of this approach, such as the use of macros, a central enum of all types, and reliance on a global table. The latter might cause trouble if the shape classes are built into multiple shared libraries. However, in my humble opinion, this approach is quite practical (except for very special use cases), since it does not lead to the explosion of code that other approaches (e.g. double dispatch) do.

Jarod42

With double dispatch, it would be something like:

class Circle;
class Box;

// Overloaded functions (Defined elsewhere)
void ResolveCollision(Circle& a, Box& b);
void ResolveCollision(Circle& a, Circle& b);
void ResolveCollision(Box& a, Box& b);
class PhysicsObject // A pure virtual class
{
public:
    virtual ~PhysicsObject() = default;

    virtual void ResolveCollision(PhysicsObject&) = 0;
    virtual void ResolveBoxCollision(Box&) = 0;
    virtual void ResolveCircleCollision(Circle&) = 0;
};

class Circle : public PhysicsObject
{
public:
    void ResolveCollision(PhysicsObject& other) override { return other.ResolveCircleCollision(*this); }
    void ResolveBoxCollision(Box& box) override { ::ResolveCollision(*this, box);}
    void ResolveCircleCollision(Circle& circle) override { ::ResolveCollision(*this, circle);}
    // ...
};

class Box : public PhysicsObject
{
public:
    void ResolveCollision(PhysicsObject& other) override { return other.ResolveBoxCollision(*this); }
    void ResolveBoxCollision(Box& box) override { ::ResolveCollision(box, *this);}
    void ResolveCircleCollision(Circle& circle) override { ::ResolveCollision(circle, *this);}
    // ...
};
