Reputation: 3453
I'm trying to represent an Enum where each key must have a specific type for its values. For example, I have defined this:
from enum import Enum, auto
from typing import Any

class MatchType(Enum):
    NATIVE = auto()
    DICTIONARY = auto()
    LIST = auto()

class MatchResult:
    type: MatchType
    value: Any

    def __init__(self, type: MatchType, value: Any):
        self.type = type
        self.value = value
Now, how can I associate each of those types with a corresponding value type? What I mean is: if a function returns a MatchResult with type = MatchType.NATIVE, I want Mypy to check that I'm using a float | int | str | bool as the value:
def fn_returns_primitive() -> MatchResult:
    ...
    return MatchResult(MatchType.NATIVE, [1, 2, 3])  # This must NOT be allowed: NATIVE should hold an int, float, string or boolean, not a list
How could I ensure that in Python?
In Rust, for instance, you can define an enum where each variant carries its own typed data:
use std::collections::HashMap;

enum MatchType<T> {
    NATIVE(u32),
    DICTIONARY(HashMap<String, T>),
    LIST(Vec<T>),
}
Does something similar exist in Python? Any kind of help would be really appreciated.
Upvotes: 2
Views: 568
Reputation: 5907
Python has an untagged union type called Union. It is considered untagged because nothing stored alongside the value records which variant of the union is selected. For your use case, here is an untagged implementation:
from typing import TypeVar, Union

T = TypeVar("T")

MatchType = Union[int, dict[str, T], list[T]]

def get() -> MatchType:
    return [1, 2, 3]

def match_on(match_type: MatchType):
    if isinstance(match_type, int):
        print("An int.")
    elif isinstance(match_type, dict):
        print("A dict.")
    elif isinstance(match_type, list):
        print("A list.")
Notice, however, that we have to iterate through all possible MatchType variants during matching. This is because an untagged union stores no tag with its variants that we could use to index a map or table. A naive attempt at constant-time matching might look like this:
def match_on(match_type: MatchType):
    {
        int: lambda: print("An int."),
        dict: lambda: print("A dictionary."),
        list: lambda: print("A list.")
    }[type(match_type)]()
but given an instance of a subclass of int (bool, for example), this would raise a KeyError because the type isn't strictly int.
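As a small sketch of that failure mode: bool is a subclass of int, so a lookup keyed by the exact type fails with a KeyError, while an isinstance scan still matches. The names HANDLERS, match_by_type and match_by_isinstance are made up for this illustration, and the handlers return strings instead of printing so the result is easy to inspect.

```python
from typing import Callable, Dict

# Dispatch table keyed by the exact runtime type (hypothetical name).
HANDLERS: Dict[type, Callable[[], str]] = {
    int: lambda: "An int.",
    dict: lambda: "A dictionary.",
    list: lambda: "A list.",
}


def match_by_type(value: object) -> str:
    # Constant-time lookup, but only exact types are found in the table.
    return HANDLERS[type(value)]()


def match_by_isinstance(value: object) -> str:
    # Linear scan that also accepts instances of subclasses.
    for cls, handler in HANDLERS.items():
        if isinstance(value, cls):
            return handler()
    raise TypeError(f"unsupported type: {type(value).__name__}")


if __name__ == "__main__":
    print(match_by_type(1))           # exact int: found in the table
    print(match_by_isinstance(True))  # bool is matched through int
    try:
        match_by_type(True)           # type(True) is bool, not int
    except KeyError as exc:
        print("KeyError:", exc)
```

The trade-off is the one described above: the table lookup is constant time but brittle with subclasses, while the isinstance scan is robust but linear in the number of variants.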
To enable constant-time matching like the Rust compiler might emit for matching on a tagged union, you'd have to mimic a tagged union like this:
from dataclasses import dataclass
from typing import TypeVar, final, Generic, Union

T = TypeVar("T")

@final
@dataclass
class MatchNative:
    value: int

@final
@dataclass
class MatchDictionary(Generic[T]):
    value: dict[str, T]

# Avoid collision with built in type `List` by prepending `Match`.
@final
@dataclass
class MatchList(Generic[T]):
    value: list[T]

MatchType = Union[MatchNative, MatchDictionary[T], MatchList[T]]

def get():
    return MatchList([1, 2, 3])

def match_on(match_type: MatchType):
    {
        MatchNative: lambda: print("An int."),
        MatchDictionary: lambda: print("A dictionary."),
        MatchList: lambda: print("A list.")
    }[type(match_type)]()
The @dataclass decorators aren't required; they just create an __init__ for the tag classes, which I then used in the get function.
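To make that concrete, here is a rough sketch of a tag class written by hand next to its dataclass counterpart. The class names MatchNativeManual and MatchNativeData are invented for the comparison; both can be constructed the same way, and the dataclass version additionally gets a generated __eq__ and __repr__.

```python
from dataclasses import dataclass
from typing import final


@final
class MatchNativeManual:
    # Hand-written constructor: roughly what @dataclass would generate.
    def __init__(self, value: int) -> None:
        self.value = value


@final
@dataclass
class MatchNativeData:
    # @dataclass generates __init__, __repr__ and __eq__ from this field.
    value: int


if __name__ == "__main__":
    manual = MatchNativeManual(1)
    generated = MatchNativeData(1)
    print(manual.value, generated.value)
    # Only the dataclass compares by field values out of the box.
    print(generated == MatchNativeData(1))
    print(manual == MatchNativeManual(1))
```

Either form works as a tag; the dataclass is simply less to type.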
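Here, we created three classes that hold the relevant data for each variant while also serving as tags themselves, thanks to the extra layer of indirection. These classes are marked @final to rule out subclasses of the tags being passed as an instance of the Union. Because no subclasses can exist, type(match_type) is always exactly one of the three classes, and that is what makes constant-time matching by dictionary lookup safe. Note that @final is checked by static type checkers such as mypy; it is not enforced at runtime.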
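Tying this back to the question, the following is a sketch of how the tag classes could carry the "native" constraint. It assumes NATIVE should accept int, float, str or bool, as the question describes; the Native alias and the function name fn_returns_primitive are illustrative. Dataclasses do not validate field types at runtime, so the check happens only in the type checker.

```python
from dataclasses import dataclass
from typing import Dict, Generic, List, TypeVar, Union, final

T = TypeVar("T")

# Hypothetical alias for the value types the question calls "native".
Native = Union[int, float, str, bool]


@final
@dataclass
class MatchNative:
    value: Native


@final
@dataclass
class MatchDictionary(Generic[T]):
    value: Dict[str, T]


@final
@dataclass
class MatchList(Generic[T]):
    value: List[T]


MatchType = Union[MatchNative, MatchDictionary[T], MatchList[T]]


def fn_returns_primitive() -> MatchNative:
    # A type checker such as mypy should reject MatchNative([1, 2, 3]) here,
    # because a list is not part of Native. At runtime nothing stops it.
    return MatchNative(3.14)


if __name__ == "__main__":
    print(fn_returns_primitive())
```

With this layout the tag and the value type can no longer disagree, because each tag class declares the type of its own value field.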
Note that both the untagged and tagged implementations are still missing exhaustiveness checking, which Rust's match statement has. Python 3.10 is coming with a match statement, but I haven't looked into whether mypy will be able to perform exhaustiveness checking with that.
Upvotes: 4