When I use typing's TypeVar to make a function's return type generic, mypy reports an error. It seems to ignore the bound argument when it compares the dictionary's value type with the function's expected return type.
Below is an example of the situation I am facing:
from typing import Dict, List, Type, TypeVar


class Bird:
    def call(self):
        print(self.sound)


class Chicken(Bird):
    def __init__(self):
        self.sound = "bok bok"


class Owl(Bird):
    def __init__(self):
        self.sound = "hoot hoot"


T = TypeVar("T", bound=Bird)


class Instantiator:
    def __init__(self, birds: List[Type[Bird]]):
        self._bird_map: Dict[Type[Bird], Bird] = {}
        for bird in birds:
            self._bird_map[bird] = bird()

    def get_bird(self, bird_type: Type[T]) -> T:
        return self._bird_map[bird_type]
Running mypy on this reports:
temp.py:29: error: Incompatible return value type (got "Bird", expected "T")
The Instantiator works as a kind of 'tracker' that instantiates one of each bird type. I retrieve an instantiated object by passing its class, which is why I need generics. Without them, fields typed later on complain that they receive a Bird rather than a Chicken or an Owl.
Am I using TypeVar incorrectly here? Is there a different way to structure this? Or is this an oversight in mypy?
This happens because you have defined a dict whose values are typed only as the base class Bird. In get_bird, you return one of those Bird values where an object of the derived type T (for example Chicken) may be expected. Mypy will not perform an implicit Base -> Derived downcast.
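The downcast that mypy refuses to do implicitly can also be done with a runtime check. The sketch below is an addition, not part of the original answer. It assumes that mypy narrows isinstance(bird, bird_type) to T when bird_type is typed as Type[T]. Under that assumption get_bird needs no cast, and it raises TypeError if the map ever holds the wrong kind of object. The class-level sound annotation on Bird is also an addition, so that call() type-checks once it is annotated.

```python
from typing import Dict, List, Type, TypeVar


class Bird:
    sound: str  # declared here so mypy knows every Bird has a sound

    def call(self) -> None:
        print(self.sound)


class Chicken(Bird):
    def __init__(self) -> None:
        self.sound = "bok bok"


class Owl(Bird):
    def __init__(self) -> None:
        self.sound = "hoot hoot"


T = TypeVar("T", bound=Bird)


class Instantiator:
    def __init__(self, birds: List[Type[Bird]]) -> None:
        # The value type is Bird, so any lookup is statically typed as Bird.
        self._bird_map: Dict[Type[Bird], Bird] = {bird: bird() for bird in birds}

    def get_bird(self, bird_type: Type[T]) -> T:
        bird = self._bird_map[bird_type]
        # The isinstance check against a Type[T] value narrows `bird` to T for
        # the type checker and fails loudly at runtime if the map is inconsistent.
        if not isinstance(bird, bird_type):
            raise TypeError(
                f"expected {bird_type.__name__}, got {type(bird).__name__}"
            )
        return bird


tracker = Instantiator([Chicken, Owl])
chicken = tracker.get_bird(Chicken)  # statically a Chicken, not a Bird
chicken.call()  # prints: bok bok
```

Compared with cast, this costs one isinstance call per lookup. In exchange, a mismatched entry is reported where it happens instead of surfacing later as an unrelated error.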
You can also make __init__ a generic function:
T = TypeVar("T", bound=Bird)


class Instantiator:
    def __init__(self, birds: List[Type[T]]):
        self._bird_map: Dict[Type[T], T] = {}
        for bird in birds:
            self._bird_map[bird] = bird()

    def get_bird(self, bird_type: Type[T]) -> T:
        return self._bird_map[bird_type]
Or use an explicit cast:
from typing import cast


class Instantiator:
    def __init__(self, birds: List[Type[Bird]]):
        self._bird_map: Dict[Type[Bird], Bird] = {}
        for bird in birds:
            self._bird_map[bird] = bird()

    def get_bird(self, bird_type: Type[T]) -> T:
        # cast only informs the type checker; it performs no runtime check
        return cast(T, self._bird_map[bird_type])
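Here is a self-contained usage sketch of the cast version. The annotations on Bird and the subclasses are additions for illustration. typing.cast returns its argument unchanged at runtime, so the caller gets a value that mypy treats as Chicken, but nothing verifies that the stored object really is one.

```python
from typing import Dict, List, Type, TypeVar, cast


class Bird:
    sound: str

    def call(self) -> None:
        print(self.sound)


class Chicken(Bird):
    def __init__(self) -> None:
        self.sound = "bok bok"


class Owl(Bird):
    def __init__(self) -> None:
        self.sound = "hoot hoot"


T = TypeVar("T", bound=Bird)


class Instantiator:
    def __init__(self, birds: List[Type[Bird]]) -> None:
        self._bird_map: Dict[Type[Bird], Bird] = {bird: bird() for bird in birds}

    def get_bird(self, bird_type: Type[T]) -> T:
        # Tells mypy "trust me, this is a T"; no check happens at runtime.
        return cast(T, self._bird_map[bird_type])


tracker = Instantiator([Chicken, Owl])
chicken = tracker.get_bird(Chicken)  # mypy infers Chicken
chicken.call()  # prints: bok bok

# cast is a no-op at runtime: the wrong object passes straight through.
owl = Owl()
mislabeled = cast(Chicken, owl)
```

The cast version is fine as long as the dictionary is only ever filled by __init__ as shown. If other code can modify _bird_map, a runtime check is safer.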