kbg23

How can I mark Python methods using Decorators?

I have a Python class of the following form:

class MyClass:
    
    def __init__(self, *args, **kwargs):
        self.lst = [MyClass.method1]

    def method1(self):
        pass

    def method2(self):
        pass

I extend this class into a derived class:

class MyDerivedClass(MyClass):
    
    def __init__(self, *args, **kwargs):
        self.lst = [MyDerivedClass.method1, MyDerivedClass.method3]

    def method3(self):
        pass

    def method4(self):
        pass

The list lst contains a specific selection of methods: in the base class it should contain only method1, and in the derived class only method1 and method3. I would like to build lst automatically, without having to override __init__ (the reason being that overriding __init__ is quite complicated in my case, and there are a lot of methods). Is there a way to do this with decorators, something like this?

@might_have_to_use_class_decorator?
class MyClass:
    
    def __init__(self, *args, **kwargs):
        self.lst = ...  # automatically compiled from the marked methods -> [method1]

    @mark_this_method
    def method1(self):
        pass

    def method2(self):
        pass

@might_have_to_use_class_decorator?
class MyDerivedClass(MyClass):
    
    def __init__(self, *args, **kwargs):
        self.lst = ...  # automatically compiled from the marked methods -> [method1, method3]

    @mark_this_method
    def method3(self):
        pass

    def method4(self):
        pass

Is there a way to do this using decorators, perhaps by decorating our classes?

Answers (1)

Jack Taylor

The simple way is to keep a list of method names in a class variable, as @juanpa.arrivillaga suggests, but if that doesn't suit you, you can indeed do it using decorators.
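For comparison, the class-variable approach might look something like the sketch below. The attribute name `marked_method_names` is hypothetical (the comment being referred to is not reproduced here), and the base-class `__init__` looks the methods up by name on `type(self)`, so the derived class only has to extend the tuple of names instead of overriding `__init__`.

```python
class MyClass:
    # Hypothetical class attribute: names of the methods that belong in lst.
    marked_method_names = ("method1",)

    def __init__(self, *args, **kwargs):
        # Resolve the names on the most-derived class, so subclasses only need
        # to change marked_method_names, not __init__. Like MyClass.method1 in
        # the question, this yields plain functions rather than bound methods.
        self.lst = [getattr(type(self), name) for name in self.marked_method_names]

    def method1(self):
        pass

    def method2(self):
        pass


class MyDerivedClass(MyClass):
    # Extend the parent's selection instead of repeating it.
    marked_method_names = MyClass.marked_method_names + ("method3",)

    def method3(self):
        pass

    def method4(self):
        pass


print([m.__name__ for m in MyDerivedClass().lst])  # prints ['method1', 'method3']
```

The downside is that every selected name appears twice, once in the `def` and once in the tuple, and a misspelled name only surfaces as an `AttributeError` when an instance is created.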

The trick is to set an attribute on the function object with a method decorator, then collect all the marked functions with a class decorator.

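def mark_method(method):
    # Tag the function object itself so the class decorator can find it later.
    method.is_marked = True
    return method

def class_with_marked_methods(cls):
    # Start from the parent's collection, if any (found via normal inheritance).
    parent_marked_methods = getattr(cls, "marked_methods", ())
    marked_methods = list(parent_marked_methods)
    # Only inspect attributes defined directly in this class body.
    for method in cls.__dict__.values():
        if getattr(method, "is_marked", False):
            marked_methods.append(method)
    cls.marked_methods = tuple(marked_methods)
    return cls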

@class_with_marked_methods
class MyClass:
    @mark_method
    def method1(self):
        pass

    def method2(self):
        pass

@class_with_marked_methods
class MyDerivedClass(MyClass):
    @mark_method
    def method3(self):
        pass

    def method4(self):
        pass

print([method.__name__ for method in MyDerivedClass.marked_methods])
# prints ['method1', 'method3']

I've made the marked_methods attribute a tuple so that it can't be modified in place after it is created, but it could just as easily be a list.
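The answer stops at the class attribute; to get the per-instance `lst` from the question without overriding `__init__` in the derived class, one option is for the base-class `__init__` to copy `type(self).marked_methods`. The sketch below repeats the two decorators so it runs on its own; the `__init__` body is an illustrative assumption, not part of the original answer.

```python
def mark_method(method):
    # Tag the function object so the class decorator can find it.
    method.is_marked = True
    return method


def class_with_marked_methods(cls):
    # Inherit the parent's collection, then add this class's own marked functions.
    marked_methods = list(getattr(cls, "marked_methods", ()))
    for method in cls.__dict__.values():
        if getattr(method, "is_marked", False):
            marked_methods.append(method)
    cls.marked_methods = tuple(marked_methods)
    return cls


@class_with_marked_methods
class MyClass:
    def __init__(self, *args, **kwargs):
        # type(self) is the most-derived class, so a subclass's own
        # marked_methods tuple is used without overriding __init__.
        self.lst = list(type(self).marked_methods)

    @mark_method
    def method1(self):
        pass

    def method2(self):
        pass


@class_with_marked_methods
class MyDerivedClass(MyClass):
    @mark_method
    def method3(self):
        pass

    def method4(self):
        pass


print([m.__name__ for m in MyClass().lst])         # prints ['method1']
print([m.__name__ for m in MyDerivedClass().lst])  # prints ['method1', 'method3']
```

Two things to keep in mind with this design: a subclass that is not decorated simply inherits its parent's tuple, so any methods it marks are silently ignored; and overriding a marked method in a decorated subclass adds the new function alongside the parent's one rather than replacing it.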
