I have an abstract base class GameNodeState that contains a Type enum:
import abc
import enum


class GameNodeState(metaclass=abc.ABCMeta):
    class Type(enum.Enum):
        INIT = enum.auto()
        INTERMEDIATE = enum.auto()
        END = enum.auto()
The names in the enum are generic because they must make sense for any subclass of GameNodeState. But when I subclass GameNodeState, as GameState and RoundState, I would like to be able to add concrete aliases to the members of GameNodeState.Type if the enum is accessed through the subclass. For example, if the GameState subclass aliases INTERMEDIATE as ROUND and RoundState aliases INTERMEDIATE as TURN, I would like the following behaviour:
>>> GameNodeState.Type.INTERMEDIATE
<Type.INTERMEDIATE: 2>
>>> RoundState.Type.TURN
<Type.INTERMEDIATE: 2>
>>> RoundState.Type.INTERMEDIATE
<Type.INTERMEDIATE: 2>
>>> GameNodeState.Type.TURN
AttributeError: TURN
My first thought was this:
class GameState(GameNodeState):
    class Type(GameNodeState.Type):
        ROUND = GameNodeState.Type.INTERMEDIATE.value


class RoundState(GameNodeState):
    class Type(GameNodeState.Type):
        TURN = GameNodeState.Type.INTERMEDIATE.value
But an enum that already defines members can't be subclassed: Python raises a TypeError as soon as the subclass is defined.
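A quick check of that restriction, and of its one exception: an enum with no members may serve as a base class. Base, MemberlessBase, Concrete and try_to_extend are illustrative names only, not part of the question's code.

```python
import enum


class Base(enum.Enum):
    INIT = 1


def try_to_extend():
    """Attempt to subclass a populated enum; return the error, if any."""
    try:
        class Extended(Base):  # raises TypeError at class creation
            ROUND = 2
    except TypeError as exc:
        return exc
    return None


# An enum *without* members may be subclassed; only the leaf defines members.
class MemberlessBase(enum.Enum):
    def describe(self):
        return f"{self.name}={self.value}"


class Concrete(MemberlessBase):
    INIT = 1
    ROUND = 2


print(type(try_to_extend()).__name__)  # TypeError
print(Concrete.ROUND.describe())       # ROUND=2
```

Because the base enum here has members, the extension has to be a separate enum that is combined with it rather than derived from it.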
Note: there are obviously more attributes and methods in the GameNodeState hierarchy; I stripped it down to the bare minimum here to focus on this particular problem.
(Original solution below.)
I've extracted an intermediate concept from the code above, namely the concept of an enum union. It can be used to obtain the behaviour above, and it is useful in other contexts too. The code can be found here, and I've asked a Code Review question.
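The union borrows its aliasing rule from ordinary enums: inside a single enum, a second name with an already-used value becomes an alias of the first member instead of a new member. A small illustration with a hypothetical Phase enum:

```python
import enum


class Phase(enum.Enum):
    INIT = 1
    INTERMEDIATE = 2
    TURN = 2  # same value as INTERMEDIATE, so TURN becomes an alias
    END = 3


# The alias resolves to the canonical (first-defined) member.
print(Phase.TURN is Phase.INTERMEDIATE)  # True
# Iteration skips aliases, but __members__ lists them.
print([m.name for m in Phase])           # ['INIT', 'INTERMEDIATE', 'END']
print(list(Phase.__members__))           # ['INIT', 'INTERMEDIATE', 'TURN', 'END']
# Lookup by value also returns the canonical member.
print(Phase(2) is Phase.INTERMEDIATE)    # True
```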
I'll add the code here as well for reference:
import enum
import itertools as itt
import operator
from collections import Counter
from functools import reduce
from typing import Literal, Union

import more_itertools as mitt

AUTO = object()


class UnionEnumMeta(enum.EnumMeta):
    """
    The metaclass for enums which are the union of several sub-enums.

    Union enums have the _subenums_ attribute which is a tuple of the enums forming the
    union.
    """

    @classmethod
    def make_union(
        mcs, *subenums: enum.EnumMeta, name: Union[str, Literal[AUTO], None] = AUTO
    ) -> enum.EnumMeta:
        """
        Create an enum whose set of members is the union of members of several enums.

        Order matters: where two members in the union have the same value, they will
        be considered as aliases of each other, and the one appearing in the first
        enum in the sequence will be used as the canonical member (the aliases will
        be associated to this enum member).

        :param subenums: Sequence of sub-enums to make a union of.
        :param name: Name to use for the enum class. AUTO will result in a combination
            of the names of all subenums, None will result in "UnionEnum".
        :return: An enum class which is the union of the given subenums.
        """
        subenums = mcs._normalize_subenums(subenums)

        class UnionEnum(enum.Enum, metaclass=mcs):
            pass

        union_enum = UnionEnum
        union_enum._subenums_ = subenums

        # Every member name must be defined in at most one subenum.
        name_counts = Counter(
            itt.chain.from_iterable(subenum.__members__ for subenum in subenums)
        )
        if duplicate_names := {n for n, count in name_counts.items() if count > 1}:
            raise ValueError(
                f"Found duplicate member names in enum union: {duplicate_names}"
            )

        # If aliases are defined, the canonical member will be the one that appears
        # first in the sequence of subenums.
        # dict union keeps the last value for duplicate keys, so we do it in reverse:
        union_enum._value2member_map_ = value2member_map = reduce(
            operator.or_, (subenum._value2member_map_ for subenum in reversed(subenums))
        )

        # union of the _member_map_'s but always using the canonical member:
        union_enum._member_map_ = member_map = {
            name: value2member_map[member.value]
            for name, member in itt.chain.from_iterable(
                subenum._member_map_.items() for subenum in subenums
            )
        }

        # only include canonical names (no aliases) in _member_names_
        union_enum._member_names_ = list(
            mitt.unique_everseen(
                itt.chain.from_iterable(subenum._member_names_ for subenum in subenums),
                key=member_map.__getitem__,
            )
        )

        if name is AUTO:
            name = (
                "".join(subenum.__name__.removesuffix("Enum") for subenum in subenums)
                + "UnionEnum"
            )
            UnionEnum.__name__ = name
        elif name is not None:
            UnionEnum.__name__ = name

        return union_enum

    def __repr__(cls):
        return f"<union of {', '.join(map(str, cls._subenums_))}>"

    def __instancecheck__(cls, instance):
        return any(isinstance(instance, subenum) for subenum in cls._subenums_)

    @classmethod
    def _normalize_subenums(mcs, subenums):
        """Remove duplicate subenums and flatten nested unions"""
        # we will need to collapse at most one level of nesting, with the inductive
        # hypothesis that any previous unions are already flat
        subenums = mitt.collapse(
            (e._subenums_ if isinstance(e, mcs) else e for e in subenums),
            base_type=enum.EnumMeta,
        )
        subenums = mitt.unique_everseen(subenums)
        return tuple(subenums)


def enum_union(*enums, **kwargs):
    return UnionEnumMeta.make_union(*enums, **kwargs)
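To see what make_union computes without depending on more_itertools, or on how a particular Python version resolves enum attributes, here is a stdlib-only sketch that builds the same three lookup tables as plain data. union_tables, Base, Extension and Clash are hypothetical names; like the code above, the sketch reads the private _member_map_, _value2member_map_ and _member_names_ attributes. Its duplicate-name check counts every name across all subenums, so a name repeated in just two of three subenums is still caught.

```python
import enum
import itertools as itt
import operator
from collections import Counter
from functools import reduce


class Base(enum.Enum):
    INIT = 1
    INTERMEDIATE = 2
    END = 3


class Extension(enum.Enum):
    TURN = 2  # same value as Base.INTERMEDIATE


def union_tables(*subenums):
    """Hypothetical helper: the lookup tables a union enum would get, as plain data."""
    # Each member name may be defined by at most one subenum.
    counts = Counter(itt.chain.from_iterable(e._member_map_ for e in subenums))
    if duplicates := {n for n, c in counts.items() if c > 1}:
        raise ValueError(f"Found duplicate member names: {duplicates}")

    # value -> member; reversed so the earliest subenum wins on a shared value,
    # because `a | b` keeps b's value for keys present in both.
    value2member = reduce(
        operator.or_, (e._value2member_map_ for e in reversed(subenums))
    )

    # name -> canonical member (aliases point at the earliest member).
    member_map = {
        name: value2member[member.value]
        for name, member in itt.chain.from_iterable(
            e._member_map_.items() for e in subenums
        )
    }

    # Canonical names only: keep a name the first time its member is seen.
    seen, member_names = set(), []
    for name in itt.chain.from_iterable(e._member_names_ for e in subenums):
        if member_map[name] not in seen:
            seen.add(member_map[name])
            member_names.append(name)

    return value2member, member_map, member_names


v2m, members, names = union_tables(Base, Extension)
print(members["TURN"] is Base.INTERMEDIATE)  # True
print(names)                                 # ['INIT', 'INTERMEDIATE', 'END']
```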
Once we have that, we can just define the extend_enum decorator to compute the union of the base enum and the enum "extension", which results in the desired behaviour:
def extend_enum(base_enum):
    def decorator(extension_enum):
        return enum_union(base_enum, extension_enum)

    return decorator
Usage:
import abc
import enum


class GameNodeState(metaclass=abc.ABCMeta):
    class Type(enum.Enum):
        INIT = enum.auto()
        INTERMEDIATE = enum.auto()
        END = enum.auto()


class RoundState(GameNodeState):
    @extend_enum(GameNodeState.Type)
    class Type(enum.Enum):
        TURN = GameNodeState.Type.INTERMEDIATE.value


class GameState(GameNodeState):
    @extend_enum(GameNodeState.Type)
    class Type(enum.Enum):
        ROUND = GameNodeState.Type.INTERMEDIATE.value
Now all of the examples above produce the desired output, plus an added instance check: isinstance(RoundState.Type.TURN, RoundState.Type) returns True.
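The instance check works because isinstance(obj, cls) defers to the __instancecheck__ method defined on the metaclass of cls. A stripped-down, hypothetical AnyOfMeta shows that mechanism on its own, with built-in types standing in for the subenums:

```python
class AnyOfMeta(type):
    """Hypothetical metaclass: its classes accept instances of any listed part."""

    def __instancecheck__(cls, instance):
        # isinstance(x, C) calls type(C).__instancecheck__(C, x)
        return any(isinstance(instance, part) for part in cls._parts_)


class IntOrStr(metaclass=AnyOfMeta):
    _parts_ = (int, str)


print(isinstance(3, IntOrStr))    # True
print(isinstance("a", IntOrStr))  # True
print(isinstance(3.5, IntOrStr))  # False
```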
I think this is a cleaner solution because it doesn't involve mucking around with descriptors and doesn't need to know anything about the owner class (it works just as well with top-level classes).
Attribute lookup through subclasses and instances of GameNodeState should automatically link to the correct "extension" (i.e., union), as long as the extension enum is defined under the same name as in the GameNodeState superclass, so that it hides the original definition.
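That shadowing relies only on ordinary class-attribute lookup along the MRO. In this sketch, plain enums stand in for the union enums (so there is no aliasing here), and FinalRoundState is a hypothetical deeper subclass:

```python
import enum


class GameNodeState:
    class Type(enum.Enum):
        INIT = 1
        INTERMEDIATE = 2
        END = 3


class RoundState(GameNodeState):
    # Same attribute name, so it hides GameNodeState.Type for this subtree.
    class Type(enum.Enum):
        INIT = 1
        TURN = 2
        END = 3


class FinalRoundState(RoundState):  # defines no Type of its own
    pass


print(FinalRoundState.Type is RoundState.Type)     # True: found on RoundState via the MRO
print(RoundState().Type is RoundState.Type)        # True: instances see the class attribute
print(GameNodeState().Type is GameNodeState.Type)  # True: the base is unaffected
```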
Not sure how bad of an idea this is, but here is a solution using a descriptor wrapped around the enum that gets the set of aliases based on the class from which it is being accessed.
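Stripped of the enum machinery, the idea is a descriptor whose __get__ receives the class it is read through (owner) and walks owner.__mro__ to the closest class that registered a value, while __set_name__ records which class body the descriptor was assigned in. PerOwnerValue and its override method are hypothetical names for this sketch; unlike the code below, it fetches the raw descriptor with vars() instead of forwarding a method through the enum.

```python
class PerOwnerValue:
    """Hypothetical descriptor whose value depends on the class it is read through."""

    def __init__(self, value):
        self._by_owner = {}
        self._pending = value

    def __set_name__(self, owner, name):
        # Called when the class containing this descriptor is created.
        self._by_owner[owner] = self._pending

    def override(self, value):
        # Use in a subclass body; returning self makes __set_name__ fire again
        # for that subclass.
        self._pending = value
        return self

    def __get__(self, instance, owner):
        # The closest class in the MRO that registered a value wins.
        for klass in owner.__mro__:
            if klass in self._by_owner:
                return self._by_owner[klass]
        raise AttributeError(f"no value registered for {owner.__name__}")


class Node:
    label = PerOwnerValue("generic")


class Round(Node):
    # vars() returns the raw descriptor instead of triggering __get__.
    label = vars(Node)["label"].override("round")


class Turn(Round):
    pass


print(Node.label, Round.label, Turn.label)  # generic round round
print(Round().label)                        # round
```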
import enum
import itertools
from typing import Dict, Type


class ExtensibleClassEnum:
    class ExtensionWrapperMeta(enum.EnumMeta):
        @classmethod
        def __prepare__(mcs, name, bases):
            # noinspection PyTypeChecker
            classdict: enum._EnumDict = super().__prepare__(name, bases)
            classdict["_ignore_"] = ["base_descriptor", "extension_enum"]
            return classdict

        # noinspection PyProtectedMember
        def __new__(mcs, cls, bases, classdict):
            base_descriptor = classdict.pop("base_descriptor")
            extension_enum = classdict.pop("extension_enum")
            wrapper_enum = super().__new__(mcs, cls, bases, classdict)
            wrapper_enum.base_descriptor = base_descriptor
            wrapper_enum.extension_enum = extension_enum

            base, extension = base_descriptor.base_enum, extension_enum
            if set(base._member_map_.keys()) & set(extension._member_map_.keys()):
                raise ValueError("Found duplicate names in extension")
            # dict union keeps the last value for duplicate keys, so we do it in reverse:
            wrapper_enum._value2member_map_ = (
                extension._value2member_map_ | base._value2member_map_
            )
            # union of both _member_map_'s but always using the canonical member:
            wrapper_enum._member_map_ = {
                name: wrapper_enum._value2member_map_[member.value]
                for name, member in itertools.chain(
                    base._member_map_.items(), extension._member_map_.items()
                )
            }
            # aliases shouldn't appear in _member_names_
            wrapper_enum._member_names_ = list(
                m.name for m in wrapper_enum._value2member_map_.values()
            )
            return wrapper_enum

        def __repr__(self):
            # have to use vars() to avoid triggering the descriptor
            base_descriptor = vars(self)["base_descriptor"]
            return (
                f"<extension wrapper enum for {base_descriptor.base_enum}"
                f" in {base_descriptor._extension2owner[self]}>"
            )

    def __init__(self, base_enum):
        if not issubclass(base_enum, enum.Enum):
            raise TypeError(base_enum)
        self.base_enum = base_enum
        # The user won't be able to retrieve the descriptor object itself, just
        # the enum, so we have to forward calls to register_extension:
        self.base_enum.register_extension = staticmethod(self.register_extension)
        # mapping of owner class -> extension for subclasses that define an extension
        self._extensions: Dict[Type, ExtensibleClassEnum.ExtensionWrapperMeta] = {}
        # reverse mapping
        self._extension2owner: Dict[ExtensibleClassEnum.ExtensionWrapperMeta, Type] = {}
        # add the base enum as the base extension via __set_name__:
        self._pending_extension = base_enum

    @property
    def base_owner(self):
        # will be initialised after __set_name__ is called with the base owner
        return self._extension2owner[self.base_enum]

    def __set_name__(self, owner, name):
        # step 2 of register_extension: determine the class that defined it
        self._extensions[owner] = self._pending_extension
        self._extension2owner[self._pending_extension] = owner
        del self._pending_extension

    def __get__(self, instance, owner):
        # Only compute extensions once:
        if owner in self._extensions:
            return self._extensions[owner]
        # traverse the MRO until we find the closest supertype defining an extension
        for supertype in owner.__mro__:
            if supertype in self._extensions:
                extension = self._extensions[supertype]
                break
        else:
            raise TypeError(f"{owner} is not a subclass of {self.base_owner}")
        # Cache the result
        self._extensions[owner] = extension
        return extension

    def make_extension(self, extension: enum.EnumMeta):
        class ExtensionWrapperEnum(
            enum.Enum, metaclass=ExtensibleClassEnum.ExtensionWrapperMeta
        ):
            base_descriptor = self
            extension_enum = extension

        return ExtensionWrapperEnum

    def register_extension(self, extension_enum):
        """Decorator for enum extensions"""
        # need a way to determine the owner class
        # add a temporary attribute that we will use when __set_name__ is called:
        if hasattr(self, "_pending_extension"):
            # __set_name__ not called after the previous call to register_extension
            raise RuntimeError(
                "An extension was created outside of a class definition",
                self._pending_extension,
            )
        self._pending_extension = self.make_extension(extension_enum)
        return self
Usage would be as follows:
import abc
import enum


class GameNodeState(metaclass=abc.ABCMeta):
    @ExtensibleClassEnum
    class Type(enum.Enum):
        INIT = enum.auto()
        INTERMEDIATE = enum.auto()
        END = enum.auto()


class RoundState(GameNodeState):
    @GameNodeState.Type.register_extension
    class Type(enum.Enum):
        TURN = GameNodeState.Type.INTERMEDIATE.value


class GameState(GameNodeState):
    @GameNodeState.Type.register_extension
    class Type(enum.Enum):
        ROUND = GameNodeState.Type.INTERMEDIATE.value
Then:
>>> (RoundState.Type.TURN
...  == RoundState.Type.INTERMEDIATE
...  == GameNodeState.Type.INTERMEDIATE
...  == GameState.Type.INTERMEDIATE
...  == GameState.Type.ROUND)
True
>>> RoundState.Type.__members__
mappingproxy({'INIT': <Type.INIT: 1>,
'INTERMEDIATE': <Type.INTERMEDIATE: 2>,
'END': <Type.END: 3>,
'TURN': <Type.INTERMEDIATE: 2>})
>>> list(RoundState.Type)
[<Type.INTERMEDIATE: 2>, <Type.INIT: 1>, <Type.END: 3>]
>>> GameNodeState.Type.TURN
Traceback (most recent call last):
  ...
  File "C:\Program Files\Python39\lib\enum.py", line 352, in __getattr__
    raise AttributeError(name) from None
AttributeError: TURN
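The iteration order shown for list(RoundState.Type) follows from how dict union orders keys: a | b keeps the key order of a, appends keys found only in b, and takes b's value for shared keys. Since __new__ computes extension._value2member_map_ | base._value2member_map_, the shared value 2 comes first. In this sketch plain strings stand in for the enum members, and the base-first variant at the end is only one possible adjustment, not something the code above does.

```python
extension = {2: "TURN"}
base = {1: "INIT", 2: "INTERMEDIATE", 3: "END"}

merged = extension | base  # requires Python 3.9+
print(list(merged))           # [2, 1, 3]: key order starts with the left operand
print(list(merged.values()))  # ['INTERMEDIATE', 'INIT', 'END']: right operand's values win

# Merging base-first keeps definition order; the extension's entries are
# filtered so they cannot replace the canonical members for shared values.
base_first = base | {k: v for k, v in extension.items() if k not in base}
print(list(base_first.values()))  # ['INIT', 'INTERMEDIATE', 'END']
```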