Nunnsy

Using Python typing's TypeVar for generically typed returns with bound

When using typing's TypeVar to give a function a generic return type, I get a mypy error: the TypeVar's bound does not seem to be taken into account when mypy compares the dictionary's value type with the function's declared return type.

Below is an example of the situation I am facing:

```python
from typing import Dict, List, Type, TypeVar


class Bird:
    def call(self):
        print(self.sound)


class Chicken(Bird):
    def __init__(self):
        self.sound = "bok bok"


class Owl(Bird):
    def __init__(self):
        self.sound = "hoot hoot"


T = TypeVar("T", bound=Bird)


class Instantiator:
    def __init__(self, birds: List[Type[Bird]]):
        self._bird_map: Dict[Type[Bird], Bird] = {}
        for bird in birds:
            self._bird_map[bird] = bird()

    def get_bird(self, bird_type: Type[T]) -> T:
        return self._bird_map[bird_type]
```

Running mypy reports: `temp.py:29: error: Incompatible return value type (got "Bird", expected "T")`

The Instantiator acts as a sort of 'tracker' that instantiates one object of each bird type. Generics are needed when retrieving an instance by its class, because otherwise later typed fields that expect a Chicken or an Owl complain about receiving a Bird.
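To make the intended usage concrete, here is a minimal runnable sketch built from the question's classes. It assumes `sound: str` is declared on `Bird` and adds `-> None` annotations; neither is in the original. It runs fine, and the `owl: Owl` annotation at the call site is exactly why `get_bird` has to return `T` rather than `Bird`. Mypy still reports the same error on the `return` line.

```python
from typing import Dict, List, Type, TypeVar


class Bird:
    sound: str  # assumption: declared on the base class so call() type-checks

    def call(self) -> None:
        print(self.sound)


class Chicken(Bird):
    def __init__(self) -> None:
        self.sound = "bok bok"


class Owl(Bird):
    def __init__(self) -> None:
        self.sound = "hoot hoot"


T = TypeVar("T", bound=Bird)


class Instantiator:
    """Keeps exactly one instance of each registered bird class."""

    def __init__(self, birds: List[Type[Bird]]) -> None:
        self._bird_map: Dict[Type[Bird], Bird] = {}
        for bird in birds:
            self._bird_map[bird] = bird()

    def get_bird(self, bird_type: Type[T]) -> T:
        # mypy rejects this line: the map's values are typed as Bird, not T
        return self._bird_map[bird_type]


instantiator = Instantiator([Chicken, Owl])
owl: Owl = instantiator.get_bird(Owl)  # the caller wants an Owl, not a Bird
owl.call()  # prints: hoot hoot
```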

Am I incorrectly using TypeVar here? Is there a different way to approach the structure? Is this an oversight in mypy?

Answers (1)

alex_noname

This is because the dict you defined holds values typed only as the base class Bird, while get_bird promises to return T, which may be a derived class such as Chicken or Owl. Mypy will not implicitly downcast from Base to Derived.
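The downcast mypy refuses to make implicitly can be justified with a runtime check. One way to do that, which is not part of the original answer, is an `isinstance` check against `bird_type`: mypy narrows a value to `T` when the class argument of `isinstance` has type `Type[T]`, so no `cast` is needed, and a wrong entry in the map raises instead of slipping through. This is a sketch with stand-in classes (the `sound` attributes are left out for brevity); the `TypeError` message is illustrative.

```python
from typing import Dict, List, Type, TypeVar


class Bird: ...


class Chicken(Bird): ...


class Owl(Bird): ...


T = TypeVar("T", bound=Bird)


class Instantiator:
    def __init__(self, birds: List[Type[Bird]]) -> None:
        self._bird_map: Dict[Type[Bird], Bird] = {bird: bird() for bird in birds}

    def get_bird(self, bird_type: Type[T]) -> T:
        bird = self._bird_map[bird_type]  # static type here: Bird
        # isinstance() against a Type[T] narrows `bird` from Bird to T,
        # so mypy accepts the return; a mismatched entry raises instead.
        if not isinstance(bird, bird_type):
            raise TypeError(
                f"expected {bird_type.__name__}, got {type(bird).__name__}"
            )
        return bird


instantiator = Instantiator([Chicken, Owl])
owl = instantiator.get_bird(Owl)  # inferred as Owl
```

Unlike a cast, this costs one runtime check per lookup, in exchange for failing loudly if the map is ever populated incorrectly.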

You can also make `__init__` a generic function:

```python
from typing import Dict, List, Type, TypeVar

# Bird, Chicken and Owl as defined in the question
T = TypeVar("T", bound=Bird)


class Instantiator:
    def __init__(self, birds: List[Type[T]]):
        self._bird_map: Dict[Type[T], T] = {}
        for bird in birds:
            self._bird_map[bird] = bird()

    def get_bird(self, bird_type: Type[T]) -> T:
        return self._bird_map[bird_type]
```

Or use an explicit cast:

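```python
from typing import Dict, List, Type, TypeVar, cast

# Bird, Chicken and Owl as defined in the question
T = TypeVar("T", bound=Bird)


class Instantiator:
    def __init__(self, birds: List[Type[Bird]]):
        self._bird_map: Dict[Type[Bird], Bird] = {}
        for bird in birds:
            self._bird_map[bird] = bird()

    def get_bird(self, bird_type: Type[T]) -> T:
        return cast(T, self._bird_map[bird_type])
```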

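Keep in mind that `typing.cast` only informs the type checker: at runtime it returns its second argument unchanged. The sketch below uses minimal stand-in classes and deliberately corrupts a `_bird_map` entry (for illustration only) to show that a wrong value passes through a cast-based `get_bird` unnoticed.

```python
from typing import Dict, List, Type, TypeVar, cast


class Bird: ...


class Chicken(Bird): ...


class Owl(Bird): ...


T = TypeVar("T", bound=Bird)


class Instantiator:
    def __init__(self, birds: List[Type[Bird]]) -> None:
        self._bird_map: Dict[Type[Bird], Bird] = {bird: bird() for bird in birds}

    def get_bird(self, bird_type: Type[T]) -> T:
        # cast() performs no check: it hands back the stored value as-is
        return cast(T, self._bird_map[bird_type])


instantiator = Instantiator([Chicken, Owl])
print(type(instantiator.get_bird(Owl)).__name__)  # prints: Owl

# Deliberately store the wrong instance (for illustration only):
instantiator._bird_map[Owl] = Chicken()
wrong = instantiator.get_bird(Owl)  # mypy still infers Owl here
print(type(wrong).__name__)  # prints: Chicken
```

This is the trade-off of `cast`: it silences the error without adding any runtime check, so it is only as safe as the code that fills `_bird_map`.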