jacwah

Reputation: 2847

Python set intersection and __eq__

According to this page, set.intersection tests elements for equality using the __eq__ method. Can anyone explain why this fails?

>>> class Foo(object):
...     def __eq__(self, other):
...         return True
...
>>> set([Foo()]).intersection([Foo()])
set([])

I'm using Python 2.7.3. Is there another (not overly complex) way to do this?

Upvotes: 1

Views: 2042

Answers (1)

User

Reputation: 14873

If you override __eq__, you should always override __hash__, too.

"If a == b, then it must be the case that hash(a) == hash(b), else sets and dictionaries will fail." Eric

__hash__ turns an object into an integer. Dicts and sets use that integer to sort their keys or elements into buckets, so that an element can be found quickly without comparing it against every other element.
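To make the bucket idea concrete, here is a deliberately simplified model of a hash set. It is not how CPython actually stores sets (CPython uses open addressing rather than lists of buckets), and the names ToyHashSet and AlwaysEqual are made up for illustration. The lookup order is the same in spirit, though: the hash decides where to look, and __eq__ is only consulted for candidates found there.

```python
class ToyHashSet(object):
    """Hypothetical, simplified hash set: a fixed list of buckets.

    Not CPython's real layout, but it shows the same idea:
    hash() selects where to look, == confirms a match.
    """

    def __init__(self, n_buckets=8):
        self.buckets = [[] for _ in range(n_buckets)]

    def _bucket_for(self, item):
        # Only the hash decides which bucket gets searched.
        return self.buckets[hash(item) % len(self.buckets)]

    def add(self, item):
        bucket = self._bucket_for(item)
        # __eq__ is used to avoid storing duplicates within the bucket.
        if not any(existing == item for existing in bucket):
            bucket.append(item)

    def __contains__(self, item):
        # __eq__ is only called on items in the selected bucket.
        return any(existing == item for existing in self._bucket_for(item))


class AlwaysEqual(object):
    """Hypothetical class: equal to everything, with a caller-chosen hash."""

    def __init__(self, h):
        self.h = h

    def __eq__(self, other):
        return True

    def __hash__(self):
        return self.h


s = ToyHashSet()
s.add(AlwaysEqual(1))
print(AlwaysEqual(1) in s)  # True: same bucket, and __eq__ says equal
print(AlwaysEqual(2) in s)  # False: different bucket, __eq__ never consulted
```

Even though __eq__ returns True unconditionally, the second lookup fails because the hash sent it to an empty bucket. That is exactly what happens to the question's Foo.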

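If you do not override __hash__, the default implementation derives the hash from the object's identity, so two distinct objects get different hash values even when __eq__ says they are equal. As a result, a set never finds a match for an equal but distinct object.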
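A quick way to see this in CPython is to call the inherited object.__hash__ directly on two instances of the question's Foo. They compare equal, yet their identity-based hashes differ. Calling object.__hash__ explicitly also sidesteps a Python 3 difference: there, a class that defines __eq__ without __hash__ gets __hash__ set to None, so set([Foo()]) raises TypeError instead of silently returning an empty intersection.

```python
class Foo(object):
    def __eq__(self, other):
        return True  # every Foo claims to equal everything

a, b = Foo(), Foo()

print(a == b)  # True
# The default hash is derived from object identity, so two distinct
# objects that are alive at the same time get different hash values.
print(object.__hash__(a) == object.__hash__(b))  # False
```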

In your case I would do this:

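class Foo(object):
    def __eq__(self, other):
        # Every Foo equals every other Foo (and nothing else).
        return type(self) == type(other)
    def __hash__(self):
        # Equal objects must hash equally, so all Foos share one hash value.
        return 1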

Because every Foo is equal to every other Foo, they must all have the same hash value (1) and therefore land in the same bucket of the set. This way the in operator also returns True.
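With that class, the lookup from the question behaves as expected. The class is repeated here so the snippet runs on its own:

```python
class Foo(object):
    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return 1  # all Foos are equal, so they share one hash

common = set([Foo()]).intersection([Foo()])
print(len(common))            # 1
print(Foo() in set([Foo()]))  # True
```

A constant hash is acceptable here only because all Foo instances are equal, so a set never holds more than one of them. For a class whose instances can differ, a constant hash would put every instance in the same bucket and turn lookups into linear scans.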

What __eq__ should look like:

  • if you only compare Foo objects

    def __eq__(self, other):
        return self.number == other.number
    
  • if you also compare Foo objects to other objects:

    def __eq__(self, other):
        return type(self) == type(other) and self.number == other.number
    
  • if you have different classes with different definitions of equality, I recommend double dispatch.

    class Foo(object):
        def __init__(self, number):
            self.number = number
        def __eq__(self, other):
            return hasattr(other, '_equals_foo') and other._equals_foo(self)
        def __hash__(self):
            return hash(self.number) # must agree with _equals_foo
        def _equals_foo(self, other):
            return self.number == other.number
        def _equals_bar(self, other):
            return False # Foo never equals Bar
    class Bar(object):
        def __eq__(self, other):
            return hasattr(other, '_equals_bar') and other._equals_bar(self)
        def __hash__(self):
            return 1 # every Bar equals every Bar, so all share one hash
        def _equals_foo(self, other):
            return False # Foo never equals Bar
        def _equals_bar(self, other):
            return True # Bar always equals Bar
    

    This way both operands of a == b take part in deciding what equality means.
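The same rule applies to the attribute-based versions in the first two bullets: whatever __eq__ compares, __hash__ should hash. A minimal sketch, assuming (as the bullets do) that each Foo has a number attribute; the constructor is added only to make it runnable:

```python
class Foo(object):
    def __init__(self, number):
        self.number = number  # hypothetical attribute, as in the bullets

    def __eq__(self, other):
        return type(self) == type(other) and self.number == other.number

    def __hash__(self):
        # Hash exactly the data __eq__ compares, so equal Foos hash equally.
        return hash(self.number)


a = set([Foo(1), Foo(2)])
print(len(a.intersection([Foo(2), Foo(3)])))  # 1
print(Foo(2) in a)  # True
print(Foo(3) in a)  # False
```

Hashing self.number keeps distinct Foos spread across buckets, unlike the constant hash, while still guaranteeing that a == b implies hash(a) == hash(b).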

Upvotes: 5
