Ivan Balepin

Reputation: 487

Enforcing proper usage of Python Enum

So I'm working with a large Django codebase that uses Python Enums throughout, e.g.:

from enum import Enum

class Status(Enum):
    active = 'active'

# ... later
assert some_django_model_instance.status == Status.active.value  # so far so good

...but of course the ".value" part gets forgotten and left off all the time. It would be hard to ditch Enums altogether by now, although they have been more problematic than useful. Is there a way to auto-check for lines like these:

assert some_django_model_instance.status == Status.active  # someone forgot ".value" here!

with, say, mypy or pylint, or perhaps by adding some code/asserts to the base Enum? The problem is that Status.active doesn't call any code: it just evaluates to the enum member (an instance of Status), and that member never compares equal to some_django_model_instance.status, which is a string.
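A minimal reproduction of the pitfall, using the Status enum from above and a plain string (the illustrative name `stored`) standing in for the value a CharField holds on the model instance:

```python
from enum import Enum


class Status(Enum):
    active = 'active'


# Stand-in for some_django_model_instance.status: a CharField holds a plain str.
stored = 'active'

# Status.active is an instance of Status, not a str.
print(isinstance(Status.active, Status))  # True
print(isinstance(Status.active, str))     # False

# A plain Enum has no custom __eq__, so comparing the raw string to the
# member falls back to identity and is always False...
print(stored == Status.active)            # False
# ...while comparing against the member's .value works.
print(stored == Status.active.value)      # True
```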

Upvotes: 2

Views: 1288

Answers (2)

Lord Elrond

Reputation: 16032

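You can sidestep this by subclassing enum.Enum and overriding __eq__ so that members compare equal to their raw values:

from enum import Enum as _Enum

class Enum(_Enum):
    def __eq__(self, arg):
        # Members of the same enum compare by identity; anything else
        # is compared against the member's raw value.
        if isinstance(arg, self.__class__):
            return arg is self
        return self.value == arg

    # Defining __eq__ sets __hash__ to None, which would make members
    # unusable as dict keys or set elements; hash on the value so that
    # members and equal raw values hash alike.
    def __hash__(self):
        return hash(self.value)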

Now you never have to call enum.value for comparison:

class Method(Enum):
    GET = 'GET'
    POST = 'POST'

>>> get = 'GET'
>>> Method.GET == get
True
>>> get == Method.GET
True
>>> Method.POST == Method.GET
False

This solves the comparison problem, since nobody has to remember .value there any more, but it creates a larger one: it becomes much more likely that someone will also leave off .value when assigning to a model field, and a plain CharField converts the member with str(), which gives 'Method.GET' rather than 'GET'.

To fix this, I'd recommend also subclassing models.CharField to create your own enum field:

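from django.db import models

class EnumField(models.CharField):
    def __init__(self, enum, **kwargs):
        self.enum = enum
        # CharField requires max_length; default it to the longest value
        # (assumes the enum's values are strings).
        kwargs.setdefault('max_length', max(len(member.value) for member in enum))
        super().__init__(**kwargs)

    def from_db_value(self, value, expression, connection):
        # Database string -> enum member when a row is loaded.
        if value is not None:
            return self.enum(value)
        return None

    def to_python(self, value):
        # Member or raw string -> member; self.enum(value) raises
        # ValueError for a string that is not a valid value.
        if value is None or isinstance(value, self.enum):
            return value
        return self.enum(value)

    def get_prep_value(self, value):
        # CharField.get_prep_value() calls to_python() above, so this
        # accepts members or raw strings; store the raw value.
        value = super().get_prep_value(value)
        if isinstance(value, self.enum):
            value = value.value
        return value

    def run_validators(self, value):
        # CharField's MaxLengthValidator calls len(), which needs a str.
        if isinstance(value, self.enum):
            value = value.value
        super().run_validators(value)

    def deconstruct(self):
        name, path, args, kwargs = super().deconstruct()
        args.append(self.enum)
        return name, path, args, kwargs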
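Stripped of Django, the field is just a two-way conversion: a member (or a raw string) goes in and its raw value is stored, and a stored string comes back out as a member. A framework-free sketch of that round trip, where `to_db` and `from_db` are hypothetical helper names standing in for get_prep_value and from_db_value (not Django APIs):

```python
from enum import Enum


class Method(Enum):
    GET = 'GET'
    POST = 'POST'


def to_db(enum_cls, value):
    """Hypothetical stand-in for get_prep_value: return the raw value to store."""
    if value is None:
        return None
    # enum_cls(member) returns the member itself, and enum_cls('GET') looks
    # the member up by value (raising ValueError for an unknown value).
    return enum_cls(value).value


def from_db(enum_cls, raw):
    """Hypothetical stand-in for from_db_value: rebuild the member."""
    return None if raw is None else enum_cls(raw)


print(to_db(Method, Method.GET))  # GET
print(to_db(Method, 'POST'))      # POST
print(from_db(Method, 'GET'))     # Method.GET
```

Routing both inputs through enum_cls(...) is what lets callers pass either form while still rejecting strings that are not valid values.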

Now you can also insert into models without calling .value:

class MyModel(models.Model):
    method = EnumField(enum=Method)

>>> MyModel.objects.create(method=Method.GET)

Upvotes: 3

Michael0x2a

Reputation: 64188

You can make mypy detect these problematic comparisons by enabling the --strict-equality command-line flag (or the equivalent strict_equality config option). With it enabled, doing some_str == Status.active will produce an error like the following:

error: Non-overlapping equality check (left operand type: "str", right operand type: "Literal[Status.active]")
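A minimal file to try this on, assuming mypy is run as `mypy --strict-equality example.py` (or with `strict_equality = True` in the mypy config). The function names are illustrative; mypy should report the comparison in `is_active_buggy`, and at runtime that comparison really is always False:

```python
from enum import Enum


class Status(Enum):
    active = 'active'


def is_active_buggy(status: str) -> bool:
    # Reported under --strict-equality: str vs Literal[Status.active].
    return status == Status.active


def is_active(status: str) -> bool:
    # Fine: .value is inferred as str, so both operands are str.
    return status == Status.active.value


print(is_active_buggy('active'))  # False
print(is_active('active'))        # True
```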

Note: this flag will check for all always-false equality comparisons, not just ones involving enums.

It will, however, be disabled for any equality comparisons where either operand has defined a custom __eq__ method, since the custom method could really be doing anything.

Upvotes: 2
