MPC

Preventing improper use of a user-defined data structure and making it reusable

Objective:

I have looked over multiple implementations on websites, and it seems people just assume that ints and floats will be the only elements stored in the max PQ.

What if we want to keep track of, let's say, Person objects by age, or Transactions by amount? All implementations I have seen will fail at run time if the user inserts whatever type they want.

Ideally, I want to:

  1. Allow users to reuse this PQ implementation for other data types
  2. Fail fast if used improperly (such as inserting an instance of a class that the implementation does not know how to compare)

In order to implement a PQ, we need to be able to compare objects using comparison operators; let's assume all we need are > and >=.

After some research, I saw that user-defined classes can implement the following special methods, which I believe would allow for more flexibility:

==  __eq__
!=  __ne__
<   __lt__
<=  __le__
>   __gt__
>=  __ge__
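As an illustration of the methods listed above (a sketch, not part of the original question): a class does not have to write all six by hand. functools.total_ordering derives the missing ordering methods from __eq__ plus one of __lt__, __le__, __gt__ or __ge__, so a Person ordered by age only needs two of them. Ordering by the age field is an assumption made for this example.

```python
from functools import total_ordering


@total_ordering
class Person:
    """Example element type, ordered by age (illustrative only)."""

    def __init__(self, name: str, age: int) -> None:
        self.name = name
        self.age = age

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Person):
            return NotImplemented  # let Python try the reflected operation
        return self.age == other.age

    def __gt__(self, other: object) -> bool:
        if not isinstance(other, Person):
            return NotImplemented
        return self.age > other.age


alice = Person("Alice", 30)
bob = Person("Bob", 25)
# __ge__, __lt__ and __le__ are generated by total_ordering from __gt__ and __eq__
print(alice > bob, alice >= bob, alice < bob)  # True True False
```

Returning NotImplemented for foreign types means a comparison such as alice > 5 raises TypeError instead of silently returning a wrong answer.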

So, could I do a check in the constructor to make sure the comparison methods I need are present, and if not, raise an exception? If I am approaching this incorrectly, what other route should I explore?
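One caveat for any runtime check: a plain hasattr test cannot work, because every class inherits __lt__, __gt__ and the other comparison methods from object (the inherited versions just return NotImplemented). A minimal sketch of a value-level check, using a hypothetical helper named supports_ordering, that instead asks whether the value's class overrides the operators. functools.total_ordering uses the same "getattr(cls, op) is not getattr(object, op)" test internally to find user-defined comparisons.

```python
from typing import Any, Tuple


class Animal:
    """Defines no ordering methods of its own."""

    def __init__(self, breed: str) -> None:
        self.breed = breed


def supports_ordering(value: Any, ops: Tuple[str, ...] = ("__gt__", "__ge__")) -> bool:
    """Return True if type(value) overrides every comparison method in ops.

    Methods inherited unchanged from object only return NotImplemented,
    so they do not count as support.
    """
    cls = type(value)
    return all(getattr(cls, op, None) is not getattr(object, op) for op in ops)


# hasattr is useless here: object provides __gt__ for every class
print(hasattr(Animal("pug"), "__gt__"))  # True
print(supports_ordering(Animal("pug")))  # False
print(supports_ordering(3), supports_ordering(2.5), supports_ordering("abc"))  # True True True
```

Passing this check only says the class defines the operators; it does not guarantee that two different types can be compared with each other (3 > "abc" still raises TypeError).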


Barebones code:

from typing import TypeVar, Generic, List

# define what types our generic class expects; ideally only classes that conform to an interface
# (i.e. that define the methods needed to compare a variety of classes)
T = TypeVar("T", int, float)


class MaxHeapPriorityQueue(Generic[T]):
    def __init__(self) -> None:
        # check if __gt__, __ge__, etc. are defined on type T; if some that we need are missing, raise an exception
        self._heap: List[T] = []
        self._insert_pointer: int = 0

    def insert(self, value: T) -> None:
        raise NotImplementedError  # TODO implement

    def delete_max(self) -> T:
        raise NotImplementedError  # TODO implement

    def __trickle_up(self, node_index: int) -> None:
        parent_index = self.__calculate_parent_node_index(node_index)

        # item-to-item comparison, which may fail or lead to logic bugs if the user stored non-numerical values in the heap
        while node_index > 1 and self._heap[node_index] > self._heap[parent_index]:
            self.__exchange(node_index, parent_index)
            node_index = parent_index
            parent_index = self.__calculate_parent_node_index(node_index)

    @staticmethod
    def __calculate_parent_node_index(child_node_index: int) -> int:
        return child_node_index // 2

    def __exchange(self, node_index_1: int, node_index_2: int) -> None:
        raise NotImplementedError  # TODO implement


Edit: Using a Protocol, the mypy check seems to work, but nothing is raised at runtime. The check if not isinstance(T, SupportsComparison): raise TypeError('can not instantiate with that type') in the constructor never enters the if branch.
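A likely explanation for the if branch never being entered: inside __init__, T is the TypeVar object itself, not the type argument Person, and isinstance against a @runtime_checkable protocol only verifies that the listed methods exist. Since object already defines all the comparison methods, that check succeeds for the TypeVar and for any instance, including an Animal. A small reproduction, using a reduced protocol with only __gt__ and __ge__ for illustration:

```python
from typing import Protocol, TypeVar, runtime_checkable


@runtime_checkable
class SupportsComparison(Protocol):
    def __gt__(self, other) -> bool: ...

    def __ge__(self, other) -> bool: ...


T = TypeVar("T", bound=SupportsComparison)


class Animal:
    def __init__(self, breed: str) -> None:
        self.breed = breed


# T is a TypeVar instance; it inherits __gt__/__ge__ from object
print(isinstance(T, SupportsComparison))  # True
# Animal defines no comparisons, yet passes for the same reason
print(isinstance(Animal("pug"), SupportsComparison))  # True
```

So the protocol is useful to mypy, which checks the type argument statically, but at runtime it cannot tell an Animal from a Person.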

Generic DS:

from typing import TypeVar, Generic, List, Protocol, runtime_checkable


@runtime_checkable
class SupportsComparison(Protocol):

    def __lt__(self, other) -> bool: ...

    def __le__(self, other) -> bool: ...

    def __eq__(self, other) -> bool: ...

    def __ne__(self, other) -> bool: ...

    def __ge__(self, other) -> bool: ...

    def __gt__(self, other) -> bool: ...


T = TypeVar("T", bound=SupportsComparison)


class MaxHeapPriorityQueue(Generic[T]):
    def __init__(self) -> None:
        if not isinstance(T, SupportsComparison):
            raise TypeError('can not instantiate with that type')

        # the heap is stored 0-based: the root is at index 0 and the children of node i are at
        # 2 * i + 1 and 2 * i + 2, so appending a new element always keeps the tree complete
        self._heap: List[T] = []

    def insert(self, value: T) -> None:
        # add the new element as the last leaf, then move it up until its parent is not smaller
        self._heap.append(value)
        self.__trickle_up(len(self._heap) - 1)

    def delete_max(self) -> T:
        if not self._heap:
            raise IndexError("Can not remove when PQ is empty")
        # TODO implement removal; for now this only returns the max without removing it
        return self._heap[0]

    def __trickle_up(self, node_index: int) -> None:
        parent_index = self.__calculate_parent_node_index(node_index)

        # we want to stop trickling up if we have reached the root of the binary tree or the node we are trickling up
        # is not greater than its parent
        while node_index > 0 and self._heap[node_index] > self._heap[parent_index]:
            self.__exchange(node_index, parent_index)
            node_index = parent_index
            parent_index = self.__calculate_parent_node_index(node_index)

    @staticmethod
    def __calculate_parent_node_index(child_node_index: int) -> int:
        return (child_node_index - 1) // 2

    def __exchange(self, node_index_1: int, node_index_2: int) -> None:
        self._heap[node_index_1], self._heap[node_index_2] = self._heap[node_index_2], self._heap[node_index_1]
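The delete_max TODO is commonly completed by moving the last leaf to the root and trickling it down (a sift-down). A minimal sketch as standalone functions over a 0-based list (root at index 0, children of i at 2 * i + 1 and 2 * i + 2); the names delete_max and _trickle_down here are hypothetical, and the index arithmetic differs if the heap keeps its root at index 1.

```python
from typing import Any, List


def delete_max(heap: List[Any]) -> Any:
    """Remove and return the largest item of a 0-based max-heap stored in a list."""
    if not heap:
        raise IndexError("Can not remove when PQ is empty")
    max_item = heap[0]
    last = heap.pop()  # shrink the tree by removing the last leaf
    if heap:  # if anything is left, put the old last leaf at the root
        heap[0] = last
        _trickle_down(heap, 0)
    return max_item


def _trickle_down(heap: List[Any], node_index: int) -> None:
    size = len(heap)
    while True:
        left, right = 2 * node_index + 1, 2 * node_index + 2
        largest = node_index
        if left < size and heap[left] > heap[largest]:
            largest = left
        if right < size and heap[right] > heap[largest]:
            largest = right
        if largest == node_index:  # neither child is bigger: heap property restored
            return
        heap[node_index], heap[largest] = heap[largest], heap[node_index]
        node_index = largest


heap = [9, 7, 8, 3, 5]  # already a valid max-heap
print([delete_max(heap) for _ in range(5)])  # [9, 8, 7, 5, 3]
```

Like trickle-up, this only ever uses >, so it works for any element type that defines __gt__.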


Classes used to instantiate it:

class Person:

    def __init__(self, name, age):
        self.name = name
        self.age = age

    def __lt__(self, other) -> bool:
        return True  #TODO IMPLEMENT, THIS IS JUST A TEST

    def __le__(self, other) -> bool:
        return True  #TODO IMPLEMENT, THIS IS JUST A TEST

    def __eq__(self, other) -> bool:
        return True  #TODO IMPLEMENT, THIS IS JUST A TEST

    def __ne__(self, other) -> bool:
        return True  #TODO IMPLEMENT, THIS IS JUST A TEST

    def __gt__(self, other) -> bool:
        return True #TODO IMPLEMENT, THIS IS JUST A TEST

    def __ge__(self, other) -> bool:
        return True  #TODO IMPLEMENT, THIS IS JUST A TEST


class Animal:

    def __init__(self, breed):
        self.breed = breed




Executing check:

if __name__ == '__main__':
    max_pq = MaxHeapPriorityQueue[Person]()  ## passes mypy check
    max_pq2 = MaxHeapPriorityQueue[Animal]()  ## fails mypy check
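Because mypy only checks annotations, MaxHeapPriorityQueue[Animal]() still runs when the script is executed. To also fail fast at runtime, one option is to validate each value in insert, so even the first insert (which performs no comparison) rejects an unorderable object. A minimal sketch, assuming a hypothetical CheckedMaxHeap class with the same bound-TypeVar signature, trimmed to insert and a hypothetical peek_max; the override test is the one functools.total_ordering uses to detect user-defined comparisons.

```python
from typing import Any, Generic, List, Protocol, TypeVar


class SupportsComparison(Protocol):
    def __gt__(self, other: Any) -> bool: ...

    def __ge__(self, other: Any) -> bool: ...


T = TypeVar("T", bound=SupportsComparison)


class CheckedMaxHeap(Generic[T]):
    """Trimmed max-heap that rejects unorderable values on insert (sketch)."""

    def __init__(self) -> None:
        self._heap: List[T] = []  # 0-based: parent of i is (i - 1) // 2

    def insert(self, value: T) -> None:
        cls = type(value)
        # methods inherited unchanged from object only return NotImplemented
        for op in ("__gt__", "__ge__"):
            if getattr(cls, op) is getattr(object, op):
                raise TypeError(f"{cls.__name__} does not define {op}")
        self._heap.append(value)
        i = len(self._heap) - 1
        while i > 0:  # trickle up
            parent = (i - 1) // 2
            if not self._heap[i] > self._heap[parent]:
                break
            self._heap[i], self._heap[parent] = self._heap[parent], self._heap[i]
            i = parent

    def peek_max(self) -> T:
        if not self._heap:
            raise IndexError("PQ is empty")
        return self._heap[0]


class Animal:
    def __init__(self, breed: str) -> None:
        self.breed = breed


pq: CheckedMaxHeap[int] = CheckedMaxHeap()
for n in (3, 10, 7):
    pq.insert(n)
print(pq.peek_max())  # 10

bad: CheckedMaxHeap[Any] = CheckedMaxHeap()  # Any slips past mypy
try:
    bad.insert(Animal("pug"))
except TypeError as exc:
    print(exc)  # Animal does not define __gt__
```

The static bound still catches most misuse at type-check time; the runtime check covers untyped callers and values typed as Any.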


Answers (1)

chepner

Define a Protocol that requires __ge__ and __gt__ to be defined, and use it as the bound of the TypeVar:

import typing
from typing import Generic


@typing.runtime_checkable
class SupportsComparison(typing.Protocol):
    def __ge__(self, other) -> bool:
        ...

    def __gt__(self, other) -> bool:
        ...


T = typing.TypeVar("T", bound=SupportsComparison)


class MaxHeapPriorityQueue(Generic[T]):
    ...
