How can I overload the get_data method below so that it returns the correct type based on the data_type value passed to __init__, instead of returning a union of both types?
from typing import Literal

DATA_TYPE = Literal["wood", "concrete"]

class WoodData: ...

class ConcreteData: ...

class Foo:
    def __init__(self, data_type: DATA_TYPE) -> None:
        self.data_type = data_type

    def get_data(self) -> WoodData | ConcreteData:
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()
I was thinking this could be done by making Foo generic, but I'm unsure of the implementation details.
I'd prefer not to pass WoodData/ConcreteData directly as a type parameter, because I have many methods whose return types depend on whether the init argument is "wood" or "concrete".
To illustrate that last point, I know I could add a generic that takes one of the two return types like so:
from typing import Literal

DATA_TYPE = Literal["wood", "concrete"]

class WoodData: ...

class ConcreteData: ...

class Foo[MY_RETURN_TYPE: WoodData | ConcreteData]:
    def __init__(self, data_type: DATA_TYPE) -> None:
        self.data_type = data_type

    def get_data(self) -> MY_RETURN_TYPE:
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()
But imagine I have many methods that conditionally return different types based on the value of data_type. I don't want to declare a type parameter for each of them; I'd rather overload the methods on the class and have the return types inferred accurately.
Lastly, I know I could split this into two separate subclasses, but it would be nice to keep everything in one class if possible.
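For comparison, the two-subclass alternative might look roughly like the sketch below. The names FooBase, WoodFoo, ConcreteFoo and the make_foo factory are hypothetical. An overloaded factory keeps a single string-based entry point while each subclass declares exact return types, so no type parameter is needed; the cost is that every data-dependent method is written twice. It uses from __future__ import annotations so the X | Y annotation also works before Python 3.10.

```python
from __future__ import annotations

from typing import Literal, overload


class WoodData: ...


class ConcreteData: ...


class FooBase:
    """Hypothetical base: behaviour that does not depend on data_type goes here."""


class WoodFoo(FooBase):
    data_type: Literal["wood"] = "wood"

    def get_data(self) -> WoodData:
        return WoodData()


class ConcreteFoo(FooBase):
    data_type: Literal["concrete"] = "concrete"

    def get_data(self) -> ConcreteData:
        return ConcreteData()


# Hypothetical factory: the overloads map each literal to its subclass,
# so callers still choose the variant with a string.
@overload
def make_foo(data_type: Literal["wood"]) -> WoodFoo: ...
@overload
def make_foo(data_type: Literal["concrete"]) -> ConcreteFoo: ...
def make_foo(data_type: Literal["wood", "concrete"]) -> WoodFoo | ConcreteFoo:
    if data_type == "wood":
        return WoodFoo()
    return ConcreteFoo()
```

With this layout a type checker should infer make_foo("wood").get_data() as WoodData without any overloads on the methods themselves.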
Ok, for this solution you annotate self in each overload with the specialised generic type you want. Both mypy and pyright give similar output for reveal_type, and it works both with the base class and with a subclass that is itself generic over T:
from typing import Literal, overload, reveal_type  # typing.reveal_type exists at runtime on 3.11+

class WoodData: ...
class ConcreteData: ...

class Foo[T: (Literal['wood'], Literal['concrete'])]:
    data_type: T

    def __init__(self, data_type: T) -> None:
        self.data_type = data_type

    # Each overload narrows `self` to one specialisation of Foo.
    @overload
    def get_data(self: "Foo[Literal['wood']]") -> WoodData: ...
    @overload
    def get_data(self: "Foo[Literal['concrete']]") -> ConcreteData: ...
    @overload
    def get_data(self) -> WoodData | ConcreteData: ...  # catch-all when T is not narrowed
    def get_data(self):
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()

    @overload
    def bar(self: "Foo[Literal['wood']]") -> int: ...
    @overload
    def bar(self: "Foo[Literal['concrete']]") -> str: ...
    @overload
    def bar(self) -> int | str: ...
    def bar(self):
        if self.data_type == "wood":
            return 42
        return "42"

reveal_type(Foo('wood').get_data())      # mypy: Revealed type is "__main__.WoodData"
reveal_type(Foo('concrete').get_data())  # mypy: Revealed type is "__main__.ConcreteData"
reveal_type(Foo('wood').bar())           # mypy: Revealed type is "builtins.int"
reveal_type(Foo('concrete').bar())       # mypy: Revealed type is "builtins.str"

class Bar[T: (Literal['wood'], Literal['concrete'])](Foo[T]):
    pass

# works with inheritance too
reveal_type(Bar('wood').get_data())      # mypy: Revealed type is "__main__.WoodData"
reveal_type(Bar('concrete').get_data())  # mypy: Revealed type is "__main__.ConcreteData"
reveal_type(Bar('wood').bar())           # mypy: Revealed type is "builtins.int"
reveal_type(Bar('concrete').bar())       # mypy: Revealed type is "builtins.str"
However, mypy won't type-check the body of the implementation (it is unannotated, and mypy skips unannotated functions by default), and pyright seems to report spurious errors in the body.
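One possible way to get the implementation body checked is to give the implementation an explicit return annotation, the union of the overload return types, so that mypy treats it as a typed function. The sketch below assumes that is enough for mypy. It spells the constrained type parameter with TypeVar and Generic, which should be equivalent to the PEP 695 form above and also runs before Python 3.12. Whether it also avoids the pyright errors is an untested assumption.

```python
from __future__ import annotations

from typing import Generic, Literal, TypeVar, overload


class WoodData: ...


class ConcreteData: ...


# Constrained type variable: T is either Literal["wood"] or Literal["concrete"].
T = TypeVar("T", Literal["wood"], Literal["concrete"])


class Foo(Generic[T]):
    data_type: T

    def __init__(self, data_type: T) -> None:
        self.data_type = data_type

    @overload
    def get_data(self: Foo[Literal["wood"]]) -> WoodData: ...
    @overload
    def get_data(self: Foo[Literal["concrete"]]) -> ConcreteData: ...
    @overload
    def get_data(self) -> WoodData | ConcreteData: ...

    # The explicit return annotation makes this a typed def, so mypy checks
    # the body (unannotated defs are skipped unless --check-untyped-defs is set).
    def get_data(self) -> WoodData | ConcreteData:
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()
```

Callers are unaffected: the overloads still decide the revealed type, and the annotation on the implementation only changes how its body is checked.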