ParkerHarrelson123

Reputation: 49

Operator == Overload for Derived Classes in C++

I am writing a program that has different shape classes.

There is a base shape class similar to the following:

class Shape
    {
    public:
        Shape(int x, int y, int size, COLORREF colorRef);
        ~Shape();
        bool operator == (const Shape&) const;
        int x() const;
        int y() const;
        int size() const;

    protected:
        int xCoord;
        int yCoord;
        int shapeSize;
        COLORREF color;
    };

And then some derived classes similar to the following:

class Circle : public Shape
    {
    public:
        Circle(int x, int y, int size, COLORREF colorRef) : Shape(x, y, size, colorRef)
        {
            this->radius = (double)shapeSize / 2;
            this->xCenter = (double)xCoord + radius;
            this->yCenter = (double)yCoord - radius;
        }
        ~Circle() {}

    private:
        double radius;
        double xCenter;
        double yCenter;
    };

class Square : public Shape
    {
    public:
        Square(int x, int y, int size, COLORREF colorRef) : Shape(x, y, size, colorRef) {}
        ~Square() {}
    };

class Triangle : public Shape
    {
    public:
        Triangle(int x, int y, int size, COLORREF colorRef) : Shape(x, y, size, colorRef) {}
        ~Triangle() {}
    };

I would like to overload the == operator in the Shape class so that I can determine whether two shapes are identical. If I could assume both shapes being compared were of the same class, I know it would be fairly straightforward, but how do I go about testing whether two objects of different derived classes are equal? For example, how do I determine that Triangle t != Circle c?

Upvotes: 1

Views: 619

Answers (2)

selbie

Reputation: 104579

OK, here's an idea for using the curiously recurring template pattern (CRTP) to make implementing derived classes easier while allowing the == operator to work as expected. This may be overkill, but it should work for your scenario.

Start by filling out your base Shape class. In addition to your basic definition, it implements operator==, which invokes a helper called CompareTypesAndDimensions. That helper calls into two pure virtual methods, TypeCompare and Compare.

class Shape
{
public:
    Shape(int x, int y, int size, COLORREF colorRef) : xCoord(x), yCoord(y), shapeSize(size), color(colorRef) {}

    virtual ~Shape() {} // dynamic_cast needs a polymorphic type, i.e. at least one virtual member

    int x() const { return xCoord; }
    int y() const { return yCoord; }
    int size() const { return shapeSize; }
    COLORREF col() const { return color; }

    bool operator == (const Shape& other) const
    {
        return CompareTypesAndDimensions(other);
    }

    bool BaseShapeCompare(const Shape& other) const
    {
        return ((other.xCoord == xCoord) && (other.yCoord == yCoord) && (other.shapeSize == shapeSize) && (other.color == color));
    }

    virtual bool TypeCompare(const Shape& other) const = 0;
    virtual bool Compare(const Shape& other) const = 0;

    bool CompareTypesAndDimensions(const Shape& other) const
    {
        // make sure the type checks are reciprocal, so that we don't accidentally
        // compare a "Square" with a "Rectangle" if one inherits from the other
        if (TypeCompare(other) && other.TypeCompare(*this))
        {
            return Compare(other);
        }
        return false;
    }

protected:
    int xCoord;
    int yCoord;
    int shapeSize;
    COLORREF color;
};

The idea with the above is that Circle, Triangle, and Square could each implement their own version of TypeCompare and Compare and be done with it. But wait! What if we could save some typing by having a template base class do some of the work for us, especially validating that both compared instances are of the same type, and sparing us from writing a stock Compare function for simpler types such as Square and Triangle?
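To see what the template saves, here is a rough sketch of the hand-written version, where each concrete class overrides TypeCompare and Compare itself. It assumes a trimmed copy of the Shape class above, uses `unsigned long` as a stand-in for the Windows COLORREF typedef so it compiles without <windows.h>, and folds the reciprocal type check directly into operator==. The `usage_example` function is only for illustration.

```cpp
#include <cassert>

// Stand-in for the Windows COLORREF typedef (assumption, for portability).
using COLORREF = unsigned long;

// Trimmed-down Shape from the answer above.
class Shape
{
public:
    Shape(int x, int y, int size, COLORREF colorRef)
        : xCoord(x), yCoord(y), shapeSize(size), color(colorRef) {}
    virtual ~Shape() = default;

    bool operator==(const Shape& other) const
    {
        // Both sides must accept each other's type before dimensions are compared.
        return TypeCompare(other) && other.TypeCompare(*this) && Compare(other);
    }

    bool BaseShapeCompare(const Shape& other) const
    {
        return other.xCoord == xCoord && other.yCoord == yCoord &&
               other.shapeSize == shapeSize && other.color == color;
    }

    virtual bool TypeCompare(const Shape& other) const = 0;
    virtual bool Compare(const Shape& other) const = 0;

protected:
    int xCoord;
    int yCoord;
    int shapeSize;
    COLORREF color;
};

// Hand-written overrides: every concrete shape repeats this boilerplate.
class Square : public Shape
{
public:
    using Shape::Shape; // inherit Shape's constructor

    bool TypeCompare(const Shape& other) const override
    {
        return dynamic_cast<const Square*>(&other) != nullptr;
    }
    bool Compare(const Shape& other) const override
    {
        return BaseShapeCompare(other); // Square has no extra members
    }
};

class Triangle : public Shape
{
public:
    using Shape::Shape;

    bool TypeCompare(const Shape& other) const override
    {
        return dynamic_cast<const Triangle*>(&other) != nullptr;
    }
    bool Compare(const Shape& other) const override
    {
        return BaseShapeCompare(other);
    }
};

void usage_example()
{
    Square s1(0, 0, 10, 0xFF), s2(0, 0, 10, 0xFF);
    Triangle t(0, 0, 10, 0xFF);
    assert(s1 == s2);   // same type, same dimensions
    assert(!(s1 == t)); // same dimensions, different type
}
```

The TypeCompare/Compare pair is nearly identical in every class, which is exactly the repetition the template below factors out.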

Let's introduce a template class that inherits from Shape. This class, ShapeComparable, provides the implementations of TypeCompare and Compare. The only thing it needs from the concrete class below it is a method for comparing that class's own members.

template <typename T>
class ShapeComparable : public Shape
{
public:

    ShapeComparable(int x, int y, int size, COLORREF colorRef) : Shape(x, y, size, colorRef)
    {}

    bool TypeCompare(const Shape& other) const override
    {
        auto pOtherCastToDerived = dynamic_cast<const T*>(&other);
        return (pOtherCastToDerived != nullptr);
    }

    bool Compare(const Shape& other) const override
    {
        if (BaseShapeCompare(other))
        {
            auto pOtherCastToDerived = dynamic_cast<const T*>(&other);
            if (pOtherCastToDerived)
            {
                return this->CompareDerived(*pOtherCastToDerived);
            }
        }
        return false;
    }

    // derived classes that don't have members to compare will just inherit this member
    virtual bool CompareDerived(const T& /*other*/) const
    {
        return true;
    }
};

The key point is that TypeCompare uses dynamic_cast to check whether the other instance is a T. If you compare a Triangle to a Circle, the dynamic_cast returns nullptr, so operator== returns false.
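A minimal sketch of that check in isolation, using a hypothetical bare-bones hierarchy and a hypothetical helper `IsA<T>` that performs the same dynamic_cast as TypeCompare. Note that the cast also succeeds for classes derived from T, which is why the caveat at the end of this answer matters.

```cpp
#include <cassert>

// Minimal polymorphic hierarchy (hypothetical, for illustration only).
struct Shape { virtual ~Shape() = default; };
struct Circle : Shape {};
struct Triangle : Shape {};

// Same test as ShapeComparable<T>::TypeCompare:
// true if `other` is a T or derives from T.
template <typename T>
bool IsA(const Shape& other)
{
    return dynamic_cast<const T*>(&other) != nullptr;
}

void usage_example()
{
    Circle c;
    Triangle t;
    assert(IsA<Circle>(c));  // cast succeeds
    assert(!IsA<Circle>(t)); // cast yields nullptr, so operator== would return false
}
```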

Now let's see what the rest of the classes look like. Start with Circle. It inherits from ShapeComparable<Circle> and provides an implementation of CompareDerived.

class Circle : public ShapeComparable<Circle>
{
public:
    Circle(int x, int y, int size, COLORREF colorRef) : ShapeComparable(x,y,size,colorRef)
    {
        this->radius = (double)shapeSize / 2;
        this->xCenter = (double)xCoord + radius;
        this->yCenter = (double)yCoord - radius;
    }

    bool CompareDerived(const Circle& other) const override
    {
        // BaseShapeCompare has already been invoked by the time this method is invoked.
        return ((other.radius == radius) && (other.xCenter == xCenter) && (other.yCenter == yCenter));
    }

private:
    double radius;
    double xCenter;
    double yCenter;
};

Triangle and Square are as simple as it gets:

class Triangle : public ShapeComparable<Triangle>
{
public:
    Triangle(int x, int y, int size, COLORREF colorRef) : ShapeComparable(x, y, size, colorRef) {}
};

class Square : public ShapeComparable<Square>
{
public:
    Square(int x, int y, int size, COLORREF colorRef) : ShapeComparable(x, y, size, colorRef) {}
};

And if you ever need to add a new property to Triangle or Square, you just need to provide a CompareDerived method for it.
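For example, suppose Triangle gains a rotation angle. The `rotation` member and the extra constructor parameter are hypothetical, not part of the question. This sketch assembles the pieces above, again with `unsigned long` standing in for COLORREF:

```cpp
#include <cassert>

using COLORREF = unsigned long; // stand-in for the Windows typedef (assumption)

class Shape
{
public:
    Shape(int x, int y, int size, COLORREF colorRef)
        : xCoord(x), yCoord(y), shapeSize(size), color(colorRef) {}
    virtual ~Shape() = default;

    bool operator==(const Shape& other) const
    {
        return TypeCompare(other) && other.TypeCompare(*this) && Compare(other);
    }

    bool BaseShapeCompare(const Shape& other) const
    {
        return other.xCoord == xCoord && other.yCoord == yCoord &&
               other.shapeSize == shapeSize && other.color == color;
    }

    virtual bool TypeCompare(const Shape& other) const = 0;
    virtual bool Compare(const Shape& other) const = 0;

protected:
    int xCoord, yCoord, shapeSize;
    COLORREF color;
};

template <typename T>
class ShapeComparable : public Shape
{
public:
    ShapeComparable(int x, int y, int size, COLORREF colorRef) : Shape(x, y, size, colorRef) {}

    bool TypeCompare(const Shape& other) const override
    {
        return dynamic_cast<const T*>(&other) != nullptr;
    }

    bool Compare(const Shape& other) const override
    {
        const T* derived = dynamic_cast<const T*>(&other);
        return derived != nullptr && BaseShapeCompare(other) && CompareDerived(*derived);
    }

    // Default for shapes with no extra members.
    virtual bool CompareDerived(const T&) const { return true; }
};

class Triangle : public ShapeComparable<Triangle>
{
public:
    Triangle(int x, int y, int size, COLORREF colorRef, int rotationDegrees = 0)
        : ShapeComparable(x, y, size, colorRef), rotation(rotationDegrees) {}

    // Only the new member needs comparing; the base fields were already checked.
    bool CompareDerived(const Triangle& other) const override
    {
        return other.rotation == rotation;
    }

private:
    int rotation; // hypothetical new property
};

class Square : public ShapeComparable<Square>
{
public:
    Square(int x, int y, int size, COLORREF colorRef) : ShapeComparable(x, y, size, colorRef) {}
};

void usage_example()
{
    Triangle a(0, 0, 10, 0xFF, 0), b(0, 0, 10, 0xFF, 90), c(0, 0, 10, 0xFF, 0);
    Square s(0, 0, 10, 0xFF);
    assert(a == c);    // same type, dimensions, and rotation
    assert(!(a == b)); // differs only in the new property
    assert(!(a == s)); // different types
}
```

Only CompareDerived had to be written; TypeCompare, Compare, and the base-field comparison still come from ShapeComparable and Shape.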

This assumes that you won't derive additional shapes from another concrete shape class. Otherwise, the dynamic_cast in TypeCompare also accepts instances of the derived class, so it can't reliably tell, for example, a Square from a Rhombus.
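One common alternative, not what this answer uses, is to compare `typeid` of both operands, which matches exact dynamic types rather than "is-a" relationships. A small sketch with a hypothetical Rhombus/Square hierarchy contrasting the two checks (`SameTypeViaCast` and `SameTypeViaTypeid` are illustrative names):

```cpp
#include <cassert>
#include <typeinfo>

// Hypothetical hierarchy where one concrete shape derives from another.
struct Shape { virtual ~Shape() = default; };
struct Rhombus : Shape {};
struct Square : Rhombus {}; // a square is a special rhombus

// Like ShapeComparable<Rhombus>::TypeCompare: accepts Rhombus and anything derived from it.
bool SameTypeViaCast(const Shape& other)
{
    return dynamic_cast<const Rhombus*>(&other) != nullptr;
}

// typeid on a polymorphic reference yields the exact dynamic type.
bool SameTypeViaTypeid(const Shape& a, const Shape& b)
{
    return typeid(a) == typeid(b);
}

void usage_example()
{
    Rhombus r;
    Square s;
    assert(SameTypeViaCast(s));       // a Square passes the Rhombus check...
    assert(!SameTypeViaTypeid(r, s)); // ...but typeid tells them apart
    assert(SameTypeViaTypeid(r, Rhombus{}));
}
```

Switching TypeCompare to a typeid comparison means it rejects every subclass, which is usually what you want for equality but is a deliberate design choice.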

Upvotes: 0

PiotrNycz

Reputation: 24420

You have to determine which function to call based on the types of both objects. This technique is called double dispatch, and in C++ it is typically implemented with the Visitor pattern.

The most common implementation assumes that all derived classes (the shapes, in your example) are known in advance, so you can list them in the base class:

class Circle;
class Rectangle;
// ... forward-declare all other shapes here
class Shape {
public:
    virtual ~Shape() = default; // good habit: give every polymorphic class (one with virtual methods) a virtual destructor

    bool operator == (const Shape& other) const {
        return equalTo(other); // first virtual dispatch, on the dynamic type of *this
    }

    virtual bool equalTo(const Shape& other) const = 0;
    virtual bool doEqualTo(const Circle&) const { return false; }
    virtual bool doEqualTo(const Rectangle&) const { return false; }
    // etc. for all other shapes
};

class Circle : public Shape {
    // ... constructor, center and radius members, etc.
protected:
    bool equalTo(const Shape& other) const override
    {
        return other.doEqualTo(*this); // second virtual dispatch: calls doEqualTo(const Circle&) on other's dynamic type
    }
    bool doEqualTo(const Circle& other) const override
    {
        return other.center == center && other.radius == radius; // reached only when both shapes are Circles
    }
};

As you can see, performing the comparison takes two virtual calls, first equalTo and then doEqualTo, hence the name double dispatch.
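A self-contained sketch of the same double-dispatch scheme with both shapes filled in. The Rectangle class, the constructors and members, and the `using Shape::doEqualTo;` lines (which keep the other overloads visible and avoid overload-hiding warnings) are assumptions added for illustration. equalTo is defined after both classes so that each shape type is complete where the overload is chosen.

```cpp
#include <cassert>

class Circle;
class Rectangle;

class Shape
{
public:
    virtual ~Shape() = default;

    // First dispatch: equalTo is resolved on the dynamic type of *this.
    bool operator==(const Shape& other) const { return equalTo(other); }

    virtual bool equalTo(const Shape& other) const = 0;

    // Second-dispatch targets, one per known shape; the default is "not equal".
    virtual bool doEqualTo(const Circle&) const { return false; }
    virtual bool doEqualTo(const Rectangle&) const { return false; }
};

class Circle : public Shape
{
public:
    Circle(double cx, double cy, double r) : x(cx), y(cy), radius(r) {}

    using Shape::doEqualTo; // keep doEqualTo(const Rectangle&) visible
    bool equalTo(const Shape& other) const override;
    bool doEqualTo(const Circle& other) const override
    {
        return other.x == x && other.y == y && other.radius == radius;
    }

private:
    double x, y, radius;
};

class Rectangle : public Shape
{
public:
    Rectangle(double left, double top, double w, double h)
        : x(left), y(top), width(w), height(h) {}

    using Shape::doEqualTo; // keep doEqualTo(const Circle&) visible
    bool equalTo(const Shape& other) const override;
    bool doEqualTo(const Rectangle& other) const override
    {
        return other.x == x && other.y == y &&
               other.width == width && other.height == height;
    }

private:
    double x, y, width, height;
};

// Second dispatch: doEqualTo runs on the dynamic type of `other`,
// with the overload picked by the static type of *this.
bool Circle::equalTo(const Shape& other) const { return other.doEqualTo(*this); }
bool Rectangle::equalTo(const Shape& other) const { return other.doEqualTo(*this); }

void usage_example()
{
    Circle c1(0, 0, 5), c2(0, 0, 5), c3(1, 0, 5);
    Rectangle r(0, 0, 5, 5);
    assert(c1 == c2);    // same type, same members
    assert(!(c1 == c3)); // same type, different center
    assert(!(c1 == r));  // Circle vs Rectangle: Shape's default returns false
    assert(!(r == c1));
}
```

The trade-off is that Shape must list every concrete shape, so adding a new shape means adding another doEqualTo overload to the base class.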

Upvotes: 1
