I'm trying to override the __init__() method of a subclass of Enum. Strangely, the pattern that works with a normal class doesn't work with Enum.
The following shows the desired pattern working with a normal class:
class Integer:
    def __init__(self, a):
        """Accepts only int"""
        assert isinstance(a, int)
        self.a = a

    def __repr__(self):
        return str(self.a)

class RobustInteger(Integer):
    def __init__(self, a):
        """Accepts int or str"""
        if isinstance(a, str):
            super().__init__(int(a))
        else:
            super().__init__(a)

print(Integer(1))
# 1
print(RobustInteger(1))
# 1
print(RobustInteger('1'))
# 1
The same pattern then breaks if used with an Enum:
from enum import Enum
from datetime import date

class WeekDay(Enum):
    MONDAY = 0
    TUESDAY = 1
    WEDNESDAY = 2
    THURSDAY = 3
    FRIDAY = 4
    SATURDAY = 5
    SUNDAY = 6

    def __init__(self, value):
        """Accepts int or date"""
        if isinstance(value, date):
            super().__init__(date.weekday())
        else:
            super().__init__(value)

assert WeekDay(0) == WeekDay.MONDAY
assert WeekDay(date(2019, 4, 3)) == WeekDay.MONDAY
# ---------------------------------------------------------------------------
# TypeError Traceback (most recent call last)
# /path/to/my/test/file.py in <module>()
# 27
# 28
# ---> 29 class WeekDay(Enum):
# 30 MONDAY = 0
# 31 TUESDAY = 1
# /path/to/my/virtualenv/lib/python3.6/enum.py in __new__(metacls, cls, bases, classdict)
# 208 enum_member._name_ = member_name
# 209 enum_member.__objclass__ = enum_class
# --> 210 enum_member.__init__(*args)
# 211 # If another member with the same value was already defined, the
# 212 # new member becomes an alias to the existing one.
# /path/to/my/test/file.py in __init__(self, value)
# 40 super().__init__(date.weekday())
# 41 else:
# ---> 42 super().__init__(value)
# 43
# 44
# TypeError: object.__init__() takes no parameters
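The traceback shows where the failure happens. EnumMeta.__new__ calls enum_member.__init__(*args) while the class statement itself is still executing. The super().__init__(value) call inside that method then ends up at object.__init__, which takes no arguments. The sketch below uses a hypothetical Color enum and a hypothetical init_calls list, both only for observation. It shows that a member's __init__ runs once per member at class-definition time and does not run again when the class is called later.

```python
from enum import Enum

# Hypothetical list used only to observe when __init__ runs.
init_calls = []

class Color(Enum):
    RED = 1
    GREEN = 2

    def __init__(self, value):
        # Called by the enum machinery once per member while the class
        # statement executes, with the member's value as the argument.
        init_calls.append(value)

# Both members were initialised when the class was defined.
print(init_calls)  # [1, 2]

# Calling the class returns an existing member; __init__ is not re-run.
assert Color(1) is Color.RED
print(init_calls)  # [1, 2]
```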
You have to override the _missing_ hook. All instances of WeekDay are created when the class is first defined. WeekDay(date(...)) is a lookup by value rather than a creation operation, and __new__ first looks for pre-existing members whose values are the integers 0 to 6. If that fails, it calls _missing_, where you can convert the date object into such an integer.
from datetime import date
from enum import Enum

class WeekDay(Enum):
    MONDAY = 0
    TUESDAY = 1
    WEDNESDAY = 2
    THURSDAY = 3
    FRIDAY = 4
    SATURDAY = 5
    SUNDAY = 6

    @classmethod
    def _missing_(cls, value):
        if isinstance(value, date):
            return cls(value.weekday())
        return super()._missing_(value)
A few examples (3 April 2019, the date used in the question, was a Wednesday, so the last assertion fails):
>>> WeekDay(date(2019,3,7))
<WeekDay.THURSDAY: 3>
>>> assert WeekDay(date(2019, 4, 1)) == WeekDay.MONDAY
>>> assert WeekDay(date(2019, 4, 3)) == WeekDay.MONDAY
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AssertionError
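Two further behaviours of the _missing_ version can be checked directly. The following self-contained sketch repeats the class. datetime.datetime is a subclass of datetime.date, so the isinstance check also accepts datetimes. A value that is neither a member value nor a date falls through to Enum's default _missing_, so the lookup still fails with ValueError.

```python
from datetime import date, datetime
from enum import Enum

class WeekDay(Enum):
    MONDAY = 0
    TUESDAY = 1
    WEDNESDAY = 2
    THURSDAY = 3
    FRIDAY = 4
    SATURDAY = 5
    SUNDAY = 6

    @classmethod
    def _missing_(cls, value):
        # Only reached when `value` is not already a member value (0-6).
        if isinstance(value, date):
            # datetime subclasses date, so this branch covers both types.
            return cls(value.weekday())
        # Defer to Enum's default, which makes the lookup raise ValueError.
        return super()._missing_(value)

print(WeekDay(datetime(2019, 3, 4, 9, 30)))  # WeekDay.MONDAY

try:
    WeekDay("Monday")
except ValueError as exc:
    print("rejected:", exc)
```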
(Note: _missing_ is not available prior to Python 3.6.)
Prior to 3.6, it seems you can override EnumMeta.__call__ to make the same check, but I'm not sure if this will have unintended side effects. (Reasoning about __call__ always makes my head spin a little.)
from datetime import date
from enum import Enum, EnumMeta

# Silently convert an instance of datetime.date to a day-of-week
# integer for lookup.
class WeekDayMeta(EnumMeta):
    def __call__(cls, value, *args, **kwargs):
        if isinstance(value, date):
            value = value.weekday()
        return super().__call__(value, *args, **kwargs)

class WeekDay(Enum, metaclass=WeekDayMeta):
    MONDAY = 0
    TUESDAY = 1
    WEDNESDAY = 2
    THURSDAY = 3
    FRIDAY = 4
    SATURDAY = 5
    SUNDAY = 6
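The metaclass version can be spot-checked on a few inputs. The corrected code is repeated so the block runs on its own. The checks cover plain value lookup, date conversion, copying a member, and rejection of an unknown value. They do not prove that overriding __call__ is free of side effects.

```python
import copy
from datetime import date
from enum import Enum, EnumMeta

class WeekDayMeta(EnumMeta):
    def __call__(cls, value, *args, **kwargs):
        # Convert a date to its weekday integer before the normal lookup.
        if isinstance(value, date):
            value = value.weekday()
        return super().__call__(value, *args, **kwargs)

class WeekDay(Enum, metaclass=WeekDayMeta):
    MONDAY = 0
    TUESDAY = 1
    WEDNESDAY = 2
    THURSDAY = 3
    FRIDAY = 4
    SATURDAY = 5
    SUNDAY = 6

# Plain lookup by value still works.
assert WeekDay(4) is WeekDay.FRIDAY
# A date is converted first; 4 March 2019 was a Monday.
assert WeekDay(date(2019, 3, 4)) is WeekDay.MONDAY
# Copying a member still hands back the same singleton.
assert copy.copy(WeekDay.SUNDAY) is WeekDay.SUNDAY

# Values that are neither 0-6 nor a date are still rejected.
try:
    WeekDay(7)
except ValueError as exc:
    print("rejected:", exc)
```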
There is a much better answer, but I'm posting this anyway because it might help in understanding the issue.
The docs give this hint:
EnumMeta creates them all while it is creating the Enum class itself, and then puts a custom __new__() in place to ensure that no new ones are ever instantiated by returning only the existing member instances.
So we have to wait until the class is created before redefining __new__. With some ugly patching, this passes the test:
from enum import Enum
from datetime import date

class WeekDay(Enum):
    MONDAY = 0
    TUESDAY = 1
    WEDNESDAY = 2
    THURSDAY = 3
    FRIDAY = 4
    SATURDAY = 5
    SUNDAY = 6

wnew = WeekDay.__new__

def _new(cls, value):
    if isinstance(value, date):
        return wnew(cls, value.weekday())  # not date.weekday()
    else:
        return wnew(cls, value)

WeekDay.__new__ = _new

assert WeekDay(0) == WeekDay.MONDAY
assert WeekDay(date(2019, 3, 4)) == WeekDay.MONDAY  # not 2019,4,3
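The docs' point can be made concrete by running the same date lookup before and after the patch. This self-contained sketch uses a hypothetical lookup_succeeds helper. Before the patch, the lookup fails with ValueError because no member has a date as its value. After WeekDay.__new__ is replaced, once the class and its members exist, the same call finds MONDAY.

```python
from datetime import date
from enum import Enum

class WeekDay(Enum):
    MONDAY = 0
    TUESDAY = 1
    WEDNESDAY = 2
    THURSDAY = 3
    FRIDAY = 4
    SATURDAY = 5
    SUNDAY = 6

def lookup_succeeds(value):
    # Hypothetical helper: True if WeekDay(value) finds an existing member.
    try:
        WeekDay(value)
    except ValueError:
        return False
    return True

monday = date(2019, 3, 4)
before_patch = lookup_succeeds(monday)  # False: no member's value is a date

# Patch only after the class statement has created every member.
_enum_new = WeekDay.__new__

def _new(cls, value):
    # Convert a date to its weekday integer, then do the normal lookup.
    if isinstance(value, date):
        value = value.weekday()
    return _enum_new(cls, value)

WeekDay.__new__ = _new
after_patch = lookup_succeeds(monday)  # True: the date is converted to 0

print(before_patch, after_patch)  # False True
```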