Genarito

Reputation: 3453

Correct way to represent an Enum with values of specific types

I'm trying to represent an Enum where each key must have a specific type for its values. For example, I have defined this:

from enum import Enum, auto
from typing import Any

class MatchType(Enum):
    NATIVE = auto()
    DICTIONARY = auto()
    LIST = auto()

class MatchResult:
    type: MatchType
    value: Any

    def __init__(self, type: MatchType, value: Any):
        self.type = type
        self.value = value

Now, how can I associate each of those members with a corresponding value type? That is, if a function returns a MatchResult with type = MatchType.NATIVE, I want Mypy to check that I'm using a float | int | str | bool as the value:

def fn_returns_primitive() -> MatchResult:
    ...
    return MatchResult(MatchType.NATIVE, [1, 2, 3])  # This must NOT be allowed: NATIVE should be int, float, str or bool, not a list

How could I ensure that in Python?

In Rust, for instance, you can define an Enum where each variant carries its own data:

use std::collections::HashMap;

enum MatchType<T> {
    NATIVE(u32),
    DICTIONARY(HashMap<String, T>),
    LIST(Vec<T>)
}

Does something similar exist in Python? Any kind of help would be really appreciated.

Upvotes: 2

Views: 568

Answers (1)

Mario Ishac

Reputation: 5907

Python has an untagged union type called Union. It is considered untagged because no information is stored about which variant of the union is selected. For your use case, here is an untagged implementation:

from typing import TypeVar, Union

T = TypeVar("T")
MatchType = Union[int, dict[str, T], list[T]]

def get() -> MatchType:
    return [1, 2, 3]

def match_on(match_type: MatchType):
    if isinstance(match_type, int):
        print("An int.")
    elif isinstance(match_type, dict):
        print("A dict.")
    elif isinstance(match_type, list):
        print("A list.")
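A self-contained variant of the untagged version, returning a label instead of printing so the result can be checked. It assumes Python 3.9+ for the built-in `dict[...]`/`list[...]` generics, and `describe` is a hypothetical name. Because `isinstance` respects inheritance, a `bool` (a subclass of `int`) is classified as an int here.

```python
from typing import TypeVar, Union

T = TypeVar("T")

# Untagged union: only the runtime type of the value tells the variants apart.
MatchType = Union[int, dict[str, T], list[T]]


def get() -> MatchType:
    return [1, 2, 3]


def describe(match_type: MatchType) -> str:
    # Checks run in order, so matching is linear in the number of variants.
    if isinstance(match_type, int):
        return "An int."
    elif isinstance(match_type, dict):
        return "A dict."
    elif isinstance(match_type, list):
        return "A list."
    raise TypeError(f"Not a MatchType: {match_type!r}")


print(describe(get()))  # A list.
print(describe(True))   # An int. (bool subclasses int)
```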

Notice, however, that we have to check each possible variant in turn during matching. This is because an untagged union stores no tag alongside its variants that we could use to index a map or table. A naive attempt at constant-time matching might look like this:

def match_on(match_type: MatchType):
    {
        int: lambda: print("An int."),
        dict: lambda: print("A dictionary."),
        list: lambda: print("A list.")
    }[type(match_type)]()

but given a subclass of int (for example bool, as in match_on(True)), this would raise a KeyError because the type isn't exactly int.
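If you want table-style dispatch that still respects subclasses, one standard-library option the answer doesn't mention is `functools.singledispatch`: it picks the implementation by walking the argument's MRO and caches the choice per concrete type, so repeated calls are effectively a dictionary lookup. A minimal sketch; `describe` and the helper names are hypothetical:

```python
from functools import singledispatch


@singledispatch
def describe(value: object) -> str:
    # Fallback when no registered type appears in the value's MRO.
    raise TypeError(f"Not a MatchType: {value!r}")


@describe.register(int)
def _describe_int(value: int) -> str:
    return "An int."


@describe.register(dict)
def _describe_dict(value: dict) -> str:
    return "A dictionary."


@describe.register(list)
def _describe_list(value: list) -> str:
    return "A list."


print(describe(3))     # An int.
print(describe(True))  # An int. (bool is resolved to int via its MRO)
print(describe([1]))   # A list.
```

Unlike the dict-of-lambdas version, `describe(True)` does not raise, because `int` is found in `bool`'s MRO.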

To enable constant-time matching, like the Rust compiler might emit for a match on a tagged union, you'd have to mimic a tagged union like this:

from dataclasses import dataclass
from typing import TypeVar, final, Generic, Union

T = TypeVar("T")

@final
@dataclass
class MatchNative:
    value: int

@final
@dataclass
class MatchDictionary(Generic[T]):
    value: dict[str, T]

# Prefixed with `Match` to avoid a collision with the built-in `list` / `typing.List`.
@final
@dataclass
class MatchList(Generic[T]):
    value: list[T]

MatchType = Union[MatchNative, MatchDictionary[T], MatchList[T]]

def get() -> MatchType:
    return MatchList([1, 2, 3])

def match_on(match_type: MatchType):
    {
        MatchNative: lambda: print("An int."),
        MatchDictionary: lambda: print("A dictionary."),
        MatchList: lambda: print("A list.")
    }[type(match_type)]()
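A runnable variation of the tagged version in which each handler receives the whole variant, so it can read the payload much like a Rust match arm binds its data. It assumes Python 3.9+; `describe` and `_HANDLERS` are hypothetical names, and the handlers return strings instead of printing so the result can be checked.

```python
from dataclasses import dataclass
from typing import Any, Callable, Generic, TypeVar, Union, final

T = TypeVar("T")


@final
@dataclass
class MatchNative:
    value: int


@final
@dataclass
class MatchDictionary(Generic[T]):
    value: dict[str, T]


@final
@dataclass
class MatchList(Generic[T]):
    value: list[T]


MatchType = Union[MatchNative, MatchDictionary[T], MatchList[T]]

# The wrapper class is the tag; each handler gets the variant and uses its payload.
_HANDLERS: dict[type, Callable[[Any], str]] = {
    MatchNative: lambda m: f"An int: {m.value}",
    MatchDictionary: lambda m: f"A dictionary with keys {sorted(m.value)}",
    MatchList: lambda m: f"A list of length {len(m.value)}",
}


def get() -> MatchType:
    return MatchList([1, 2, 3])


def describe(match_type: MatchType) -> str:
    # One dict lookup, however many variants there are.
    return _HANDLERS[type(match_type)](match_type)


print(describe(get()))                              # A list of length 3
print(describe(MatchDictionary({"b": 2, "a": 1})))  # A dictionary with keys ['a', 'b']
```

The trade-off is that each handler's argument is typed as `Any`, so mypy no longer checks the payload access inside the lambdas; an `isinstance` chain keeps that precision at the cost of linear matching.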

The @dataclass decorators aren't required; they just generate an __init__ for each tag class, which the get function then uses.

Here, we created three classes that hold the relevant data for each variant while also serving as tags themselves, thanks to the extra layer of indirection. These classes are marked @final to rule out subclasses of the tags being passed as members of the Union, which is what makes the constant-time lookup safe. Note that @final is only enforced by type checkers such as mypy; at runtime Python still allows subclassing these classes.
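If you would rather keep the question's single `MatchResult` class, a different technique (not part of this answer) is to overload `__init__` so that each `MatchType` member, written as a `Literal` type, is paired with the value type it accepts. mypy should then reject `MatchResult(MatchType.NATIVE, [1, 2, 3])` because no overload matches. This is a sketch assuming Python 3.9+: the `value` attribute is still typed broadly afterwards, so reading it does not narrow by `type`, and `Native` is a hypothetical alias.

```python
from enum import Enum, auto
from typing import Any, Literal, Union, overload


class MatchType(Enum):
    NATIVE = auto()
    DICTIONARY = auto()
    LIST = auto()


# Hypothetical alias for the "native" value types from the question.
Native = Union[int, float, str, bool]


class MatchResult:
    type: MatchType
    value: Any

    # Each overload ties one enum member to the value type it accepts.
    @overload
    def __init__(self, type: Literal[MatchType.NATIVE], value: Native) -> None: ...
    @overload
    def __init__(self, type: Literal[MatchType.DICTIONARY], value: dict[str, Any]) -> None: ...
    @overload
    def __init__(self, type: Literal[MatchType.LIST], value: list[Any]) -> None: ...

    # The single runtime implementation; the overloads only matter to the type checker.
    def __init__(self, type: MatchType, value: Any) -> None:
        self.type = type
        self.value = value


ok = MatchResult(MatchType.NATIVE, 3.5)
# MatchResult(MatchType.NATIVE, [1, 2, 3])  # rejected by mypy: no overload accepts a list for NATIVE
```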

Note that both the untagged and tagged implementations are still missing exhaustiveness checking, which Rust's match expression has. Python 3.10 is coming with a match statement, but I haven't looked into whether mypy will be able to perform exhaustiveness checking with it.
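For exhaustiveness with today's tools, one pattern from the mypy documentation is an `assert_never` helper whose parameter is typed `NoReturn`: once every variant has been handled by `isinstance` checks, mypy narrows the remaining type to nothing, so the call type-checks; if a branch is missing, mypy reports an error at that call. The helper is defined locally here (Python 3.11 also ships `typing.assert_never`), `total` is a hypothetical function, and the tag classes are simplified to non-generic payloads to keep the sketch short.

```python
from dataclasses import dataclass
from typing import NoReturn, Union, final


@final
@dataclass
class MatchNative:
    value: int


@final
@dataclass
class MatchDictionary:
    value: dict[str, int]


@final
@dataclass
class MatchList:
    value: list[int]


MatchType = Union[MatchNative, MatchDictionary, MatchList]


def assert_never(value: NoReturn) -> NoReturn:
    # Only reachable at runtime if a value slipped past the type checker.
    raise AssertionError(f"Unhandled variant: {value!r}")


def total(match_type: MatchType) -> int:
    if isinstance(match_type, MatchNative):
        return match_type.value
    elif isinstance(match_type, MatchDictionary):
        return sum(match_type.value.values())
    elif isinstance(match_type, MatchList):
        return sum(match_type.value)
    else:
        # Delete any branch above and mypy flags this call, because
        # match_type is no longer narrowed to an empty type here.
        assert_never(match_type)


print(total(MatchList([1, 2, 3])))               # 6
print(total(MatchDictionary({"a": 1, "b": 2})))  # 3
```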

Upvotes: 4
