Matt

Reputation: 1632

Overload a method based on init variables

How can I overload the get_data method below to return the correct type based on the init value of data_type instead of returning a union of both types?

from typing import Literal

DATA_TYPE = Literal["wood", "concrete"]

class WoodData: ...
class ConcreteData: ...

class Foo:
    def __init__(self, data_type: DATA_TYPE) -> None:
        self.data_type = data_type

    def get_data(self) -> WoodData | ConcreteData:
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()

I was thinking this could be done by making Foo generic, but I'm unsure of the implementation details.

I'd prefer not to pass WoodData/ConcreteData directly as a generic. This is because I have many methods returning conditional data types depending on whether the init var is wood or concrete.

To illustrate that last point, I know I could add a generic that takes one of the two return types like so:

from typing import Literal

DATA_TYPE = Literal["wood", "concrete"]

class WoodData: ...
class ConcreteData: ...

class Foo[MY_RETURN_TYPE: WoodData | ConcreteData]:
    def __init__(self, data_type: DATA_TYPE) -> None:
        self.data_type = data_type

    def get_data(self) -> MY_RETURN_TYPE:
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()

But imagine I have tons of methods conditionally returning different types based on the value of data_type. I don't want to specify each of these as generics. I'd rather overload the methods on the class and have return types accurately inferred.

Lastly, I know I could split this into two separate subclasses, but it would be nice to keep them as one class if possible.
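
For reference, the subclass split I'd like to avoid would look roughly like this. It's only a sketch: BaseFoo, WoodFoo and ConcreteFoo are hypothetical names, not part of my real code.

```python
class WoodData: ...
class ConcreteData: ...


class BaseFoo:
    """Hypothetical home for behaviour that does not depend on the material."""


class WoodFoo(BaseFoo):
    # Each subclass states its own concrete return type, so no overloads are needed.
    def get_data(self) -> WoodData:
        return WoodData()


class ConcreteFoo(BaseFoo):
    def get_data(self) -> ConcreteData:
        return ConcreteData()
```

This types cleanly, but every conditional method has to be written once per subclass, which is what I'm hoping to avoid.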

Upvotes: 4

Views: 88

Answers (1)

juanpa.arrivillaga

Reputation: 96257

OK, for this solution you make Foo generic over the literal type and, in each overload, annotate self with the specific parameterization you want. Both mypy and pyright give similar outputs for reveal_type (it works with the base class and with a subclass):

from typing import Literal, overload, reveal_type


class WoodData: ...
class ConcreteData: ...

class Foo[T:(Literal['wood'], Literal['concrete'])]:
    data_type: T
    def __init__(self, data_type: T) -> None:
        self.data_type = data_type
    @overload
    def get_data(self: "Foo[Literal['wood']]") -> WoodData:
        ...
    @overload
    def get_data(self: "Foo[Literal['concrete']]") -> ConcreteData:
        ...
    @overload
    def get_data(self) -> WoodData | ConcreteData:
        ...
    def get_data(self):
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()
    @overload
    def bar(self: "Foo[Literal['wood']]") -> int:
        ...
    @overload
    def bar(self: "Foo[Literal['concrete']]") -> str:
        ...
    @overload
    def bar(self) -> int | str:
        ...
    def bar(self):
        if self.data_type == "wood":
            return 42
        return "42"

reveal_type(Foo('wood').get_data())  # Revealed type is "__main__.WoodData"
reveal_type(Foo('concrete').get_data())  # Revealed type is "__main__.ConcreteData"
reveal_type(Foo('wood').bar())  # Revealed type is "builtins.int"
reveal_type(Foo('concrete').bar())  # Revealed type is "builtins.str"

class Bar[T:(Literal['wood'], Literal['concrete'])](Foo[T]):
    pass
# works with inheritance too
reveal_type(Bar('wood').get_data())  # Revealed type is "__main__.WoodData"
reveal_type(Bar('concrete').get_data())  # Revealed type is "__main__.ConcreteData"
reveal_type(Bar('wood').bar())  # Revealed type is "builtins.int"
reveal_type(Bar('concrete').bar())  # Revealed type is "builtins.str"

However, mypy won't type check the body of the implementation (it is unannotated, and mypy skips unannotated function bodies by default), and pyright seems to report spurious errors for the body...
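
If you are on a Python older than 3.12, the class Foo[T: ...] syntax is not available. A rough equivalent using TypeVar and Generic might look like the sketch below. The TypeVar/Generic spelling is my substitution for the syntax above, and I have only reasoned about it at runtime, so treat the checker behaviour as an assumption to verify with your own mypy or pyright version.

```python
from __future__ import annotations  # keeps the self annotations unevaluated at runtime

from typing import Generic, Literal, TypeVar, Union, overload


class WoodData: ...
class ConcreteData: ...


# Constrained TypeVar: T may only be Literal["wood"] or Literal["concrete"].
T = TypeVar("T", Literal["wood"], Literal["concrete"])


class Foo(Generic[T]):
    def __init__(self, data_type: T) -> None:
        self.data_type = data_type

    @overload
    def get_data(self: Foo[Literal["wood"]]) -> WoodData: ...
    @overload
    def get_data(self: Foo[Literal["concrete"]]) -> ConcreteData: ...
    @overload
    def get_data(self) -> Union[WoodData, ConcreteData]: ...

    def get_data(self):
        # Left unannotated, as in the version above, so mypy skips this body
        # by default; the --check-untyped-defs flag asks mypy to check it anyway.
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()


wood = Foo("wood").get_data()          # a WoodData instance at runtime
concrete = Foo("concrete").get_data()  # a ConcreteData instance at runtime
```

The overloads only affect what the type checker infers. At runtime the single unannotated implementation is what actually runs.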

Upvotes: 3
