Reputation: 487
So I'm working with a large Django codebase that uses Python Enums throughout, e.g.:
from enum import Enum
class Status(Enum):
    active = 'active'
# ... later
assert some_django_model_instance.status == Status.active.value # so far so good
...but of course the ".value" part gets forgotten and left off all the time. It would be hard to ditch Enums altogether by now, although they have been more problematic than useful. Is there a way to auto-check for lines like these:
assert some_django_model_instance.status == Status.active # someone forgot ".value" here!
with, say, mypy or pylint, or perhaps by adding some code/asserts to the base Enum? The problem is, Status.active doesn't really call any code; it just evaluates to an enum member, and of course that member is never equal to some_django_model_instance.status, which is a string.
Upvotes: 2
Views: 1288
Reputation: 16032
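You can work around this by subclassing enum.Enum and overriding __eq__:
from enum import Enum as _Enum

class Enum(_Enum):
    def __eq__(self, arg):
        # Members compare by identity; anything else is compared to the value.
        if isinstance(arg, self.__class__):
            return arg is self
        return self.value == arg

    # Defining __eq__ alone would make members unhashable, so keep a hash
    # that is consistent with the value-based comparison above.
    def __hash__(self):
        return hash(self.value)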
Now you no longer need .value for comparisons:
class Method(Enum):
    GET = 'GET'
    POST = 'POST'
>>> get = 'GET'
>>> Method.GET == get
True
>>> get == Method.GET
True
>>> Method.POST == Method.GET
False
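One side effect worth knowing about: in Python, a class that defines __eq__ without also defining __hash__ becomes unhashable, so members of such an Enum can no longer be used as dict keys or set elements. Below is a minimal sketch of one way to keep them hashable, assuming all enum values are themselves hashable. The handlers dict is purely illustrative.

```python
from enum import Enum as _Enum


class Enum(_Enum):
    def __eq__(self, other):
        # Members compare by identity; anything else is compared to the value.
        if isinstance(other, self.__class__):
            return other is self
        return self.value == other

    def __hash__(self):
        # Hash on the value so that hashing agrees with __eq__:
        # objects that compare equal must have equal hashes.
        return hash(self.value)


class Method(Enum):
    GET = "GET"
    POST = "POST"


# Hypothetical lookup table keyed by enum members.
handlers = {Method.GET: "read", Method.POST: "write"}
```

Because the hash follows the value, a plain string such as "GET" finds the same dict entry as Method.GET. That is consistent with the relaxed equality, but it may be surprising if you expected members and strings to stay distinct as keys.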
This solves the problem in the sense that nobody has to remember .value for comparisons, but it creates a different problem: it is now much more likely that someone will forget .value when assigning to a model field.
To fix this, I'd recommend also subclassing models.CharField to create your own enum field:
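from django.db import models

class EnumField(models.CharField):
    def __init__(self, enum, **kwargs):
        self.enum = enum
        # Let CharField handle max_length, null, etc.
        super().__init__(**kwargs)

    def from_db_value(self, value, expression, connection):
        # Convert the raw database string back into an enum member.
        if value is not None:
            return self.enum(value)
        return None

    def to_python(self, value):
        # Accept either an enum member or a raw value; return an enum member.
        if value is None or isinstance(value, self.enum):
            return value
        return self.enum(value)

    def get_prep_value(self, value):
        # Store the underlying value, not the enum member.
        if isinstance(value, self.enum):
            value = value.value
        return super().get_prep_value(value)

    def deconstruct(self):
        # Pass the enum along so migrations can reconstruct the field.
        name, path, args, kwargs = super().deconstruct()
        args.append(self.enum)
        return name, path, args, kwargs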
Now you can also insert into models without .value:
class MyModel(models.Model):
    # CharField normally requires max_length; 10 is an arbitrary choice here.
    method = EnumField(enum=Method, max_length=10)
>>> MyModel.objects.create(method=Method.GET)
Upvotes: 3
Reputation: 64188
You can make mypy detect these problematic comparisons with the --strict-equality command-line flag (or the strict_equality config file option). With this flag enabled, some_str == Status.active produces an error like the following:
error: Non-overlapping equality check (left operand type: "str", right operand type: "Literal[Status.active]")
Note: this flag will check for all always-false equality comparisons, not just ones involving enums.
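To make this concrete, here is a small sketch of a file you could run through mypy with --strict-equality. The function names is_active_buggy and is_active are made up for illustration. The comments describe what mypy is expected to report based on the error shown above.

```python
from enum import Enum


class Status(Enum):
    active = "active"


def is_active_buggy(status: str) -> bool:
    # With --strict-equality, mypy is expected to flag this comparison as a
    # non-overlapping equality check (str vs. Literal[Status.active]).
    # At runtime it silently evaluates to False.
    return status == Status.active


def is_active(status: str) -> bool:
    # Comparing str to str: no error from mypy, and correct at runtime.
    return status == Status.active.value
```

At runtime both functions execute without raising, which is exactly why the forgotten .value goes unnoticed without a static check.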
The check is, however, disabled for any equality comparison where either operand defines a custom __eq__ method, since the custom method could be doing anything.
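A consequence of that last point is that an Enum base class with a relaxed __eq__ opts out of this check. The sketch below uses a hypothetical LenientStatus class to illustrate. The comment reflects the behavior described above rather than verified mypy output.

```python
from enum import Enum


class LenientStatus(Enum):
    active = "active"

    def __eq__(self, other: object) -> bool:
        # Custom equality: also accept the raw value.
        if isinstance(other, LenientStatus):
            return other is self
        return self.value == other

    def __hash__(self) -> int:
        # Keep members hashable and consistent with __eq__.
        return hash(self.value)


def is_active(status: str) -> bool:
    # Because LenientStatus defines its own __eq__, --strict-equality
    # should not report this comparison.
    return status == LenientStatus.active
```

Here the comparison is actually true at runtime, so the missing warning is harmless, but the same opt-out applies to any custom __eq__, including ones that do not compare against the value.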
Upvotes: 2