Yot Yot5


How can I enforce a specific output type for a specific input type in Python?

I've been learning how to use Python type hints and there's a particular use case I'm struggling with.

Let's say I have the following Pydantic models:

from pydantic import BaseModel


class Horse(BaseModel):
    speed: str
    race_wins: int

class HorseWithHat(Horse):
    hat_color: str

class Snake(BaseModel):
    length: str
    poisonous: bool

class SnakeWithHat(Snake):
    hat_color: str

# Etc.

I have various other animal models, each with an associated animal-with-hat model. I now want to implement a function that gives an animal a hat. The type signature would be something like:

def give_hat(animal: Animal, hat_color: str) -> AnimalWithHat

where Animal = Union[Horse, Snake, etc.] and AnimalWithHat = Union[HorseWithHat, SnakeWithHat, etc.]. Of course, the issue with this idea is that a Horse could go in and a SnakeWithHat could come out; I want to enforce consistency.
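To make the problem concrete, here is a minimal sketch. Plain dataclasses stand in for the Pydantic models so it runs without dependencies, and `give_hat_wrong` is a hypothetical name. A type checker accepts this function because `SnakeWithHat` is one member of the `AnimalWithHat` union, even though a `Horse` went in:

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Horse:
    speed: str
    race_wins: int


@dataclass
class HorseWithHat(Horse):
    hat_color: str


@dataclass
class Snake:
    length: str
    poisonous: bool


@dataclass
class SnakeWithHat(Snake):
    hat_color: str


Animal = Union[Horse, Snake]
AnimalWithHat = Union[HorseWithHat, SnakeWithHat]


def give_hat_wrong(animal: Animal, hat_color: str) -> AnimalWithHat:
    # Type-checks fine: SnakeWithHat is a member of the AnimalWithHat union,
    # and nothing in the signature ties the output variant to the input variant.
    return SnakeWithHat(length="1m", poisonous=False, hat_color=hat_color)


result = give_hat_wrong(Horse(speed="fast", race_wins=3), "red")
print(type(result).__name__)  # SnakeWithHat
```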

The other idea I had was to create a WithHat generic. The type signature would then be

def give_hat(animal: AnimalTypeVar, hat_color: str) -> WithHat[AnimalTypeVar]

with AnimalTypeVar being a type variable bound to Animal = Union[Horse, Snake, etc.]. This would have the advantage of condensing the repetitive WithHat model definitions; however, I haven't been able to figure out how to define a generic that works this way (adding a single attribute to an input type).
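Because a generic cannot add a field to its type argument, probably the closest expressible version of `WithHat[AnimalTypeVar]` is composition: a generic wrapper that *holds* the animal instead of extending it. This is a minimal sketch under that assumption, again with dataclasses standing in for the Pydantic models; `WithHat` and `AnimalT` are hypothetical names. (Pydantic also supports generic models, via `GenericModel` in v1 or `BaseModel` plus `Generic` in v2.)

```python
from dataclasses import dataclass
from typing import Generic, TypeVar, Union


@dataclass
class Horse:
    speed: str
    race_wins: int


@dataclass
class Snake:
    length: str
    poisonous: bool


# Bound to the union, so any single animal model is an acceptable argument.
AnimalT = TypeVar("AnimalT", bound=Union[Horse, Snake])


@dataclass
class WithHat(Generic[AnimalT]):
    animal: AnimalT  # the wrapped, unchanged animal
    hat_color: str


def give_hat(animal: AnimalT, hat_color: str) -> WithHat[AnimalT]:
    # A type checker infers WithHat[Horse] for a Horse, WithHat[Snake] for a Snake.
    return WithHat(animal=animal, hat_color=hat_color)


hatted = give_hat(Horse(speed="fast", race_wins=3), "red")
print(hatted.animal.speed, hatted.hat_color)  # fast red
```

The trade-off is that callers write `hatted.animal.speed` rather than `hatted.speed`, so this only helps if a wrapper is acceptable in place of distinct subclass models.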

I'm hoping I'm missing something simple! Any suggestions?

(I am aware I could just combine the non-hat and hat models, making hat_color an optional attribute, but in my real project this is finicky to deal with. If possible, I'd like a solution with distinct hatless and hatful models.)


Answers (1)

Daniil Fajnberg


The elegant solution would of course be intersection types, but the Python type system does not have them (yet). This has already been discussed in a few posts on SO.

If we had them, we'd just define a protocol for WithHat, bind a type variable to Animal (as you suggested) for the animal parameter, and declare the return type as the intersection of WithHat and that type variable. Alas...
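The half of that design that *can* be written today looks roughly like this (a sketch; `HasHat` is a hypothetical name, and dataclasses stand in for the Pydantic models). The protocol captures "has a `hat_color`", but there is no way to spell "the input type *and* `HasHat`" as a return type, which is exactly the missing intersection:

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasHat(Protocol):
    hat_color: str


@dataclass
class Horse:
    speed: str
    race_wins: int


@dataclass
class HorseWithHat(Horse):
    hat_color: str


# Structural check: any object with a hat_color attribute satisfies HasHat.
print(isinstance(HorseWithHat(speed="fast", race_wins=3, hat_color="red"), HasHat))  # True
print(isinstance(Horse(speed="fast", race_wins=3), HasHat))  # False

# What we would like to write, but cannot, because no intersection type exists:
#   def give_hat(animal: AnimalT, hat_color: str) -> Intersection[AnimalT, HasHat]: ...
```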

But since you mentioned that there are only a few animal models and you know them all in advance, you could just resort to overloading the signature of give_hat with typing.overload.

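from typing import Union, overload
from typing_extensions import TypeAlias  # or typing.TypeAlias on Python 3.10+

# Horse/HorseWithHat and Snake/SnakeWithHat come from the question;
# Cow/CowWithHat stand in for the "Etc." models and follow the same pattern.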


Animal: TypeAlias = Union[Horse, Snake, Cow]
AnimalWithHat: TypeAlias = Union[HorseWithHat, SnakeWithHat, CowWithHat]


@overload
def give_hat(animal: Snake, hat_color: str) -> SnakeWithHat:
    ...


@overload
def give_hat(animal: Cow, hat_color: str) -> CowWithHat:
    ...


@overload
def give_hat(animal: Horse, hat_color: str) -> HorseWithHat:
    ...


def give_hat(animal: Animal, hat_color: str) -> AnimalWithHat:
    return NotImplemented  # actual implementation here

This is not as nice, and it will obviously blow up in your face if the number of animals starts to increase, but it works.
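For the implementation body behind the overloads, one option is a lookup table from each hatless model to its hatful counterpart. The sketch below assumes that approach; `_HAT_MODELS` is a hypothetical name, and dataclasses with `asdict` stand in for the Pydantic models (with Pydantic you would copy fields via `animal.dict()` in v1 or `animal.model_dump()` in v2 instead):

```python
from dataclasses import asdict, dataclass
from typing import Dict, Union, overload


@dataclass
class Horse:
    speed: str
    race_wins: int


@dataclass
class HorseWithHat(Horse):
    hat_color: str


@dataclass
class Snake:
    length: str
    poisonous: bool


@dataclass
class SnakeWithHat(Snake):
    hat_color: str


Animal = Union[Horse, Snake]
AnimalWithHat = Union[HorseWithHat, SnakeWithHat]

# Hypothetical registry: maps each hatless model to its hatful counterpart.
_HAT_MODELS: Dict[type, type] = {Horse: HorseWithHat, Snake: SnakeWithHat}


@overload
def give_hat(animal: Horse, hat_color: str) -> HorseWithHat: ...


@overload
def give_hat(animal: Snake, hat_color: str) -> SnakeWithHat: ...


def give_hat(animal: Animal, hat_color: str) -> AnimalWithHat:
    # Look up by exact type, so an already-hatted animal raises KeyError.
    hat_model = _HAT_MODELS[type(animal)]
    # Copy every existing field, then add the new one.
    return hat_model(**asdict(animal), hat_color=hat_color)


snake = give_hat(Snake(length="2m", poisonous=True), "green")
print(snake)  # SnakeWithHat(length='2m', poisonous=True, hat_color='green')
```

The overloads give callers the precise return type, while the table keeps the runtime logic in one place; adding an animal still means one new overload and one new table entry.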

