Luper Rouch

Reputation: 9492

Writing a class decorator that applies a decorator to all methods

I'm trying to write a class decorator that applies a decorator to all the class' methods:

import inspect


def decorate_func(func):
    def wrapper(*args, **kwargs):
        print "before"
        ret = func(*args, **kwargs)
        print "after"
        return ret
    for attr in "__module__", "__name__", "__doc__":
        setattr(wrapper, attr, getattr(func, attr))
    return wrapper


def decorate_class(cls):
    for name, meth in inspect.getmembers(cls, inspect.ismethod):
        setattr(cls, name, decorate_func(meth))
    return cls


@decorate_class
class MyClass(object):

    def __init__(self):
        self.a = 10
        print "__init__"

    def foo(self):
        print self.a

    @staticmethod
    def baz():
        print "baz"

    @classmethod
    def bar(cls):
        print "bar"


obj = MyClass()
obj.foo()
obj.baz()
MyClass.baz()
obj.bar()
MyClass.bar()

It almost works, but @classmethods need special treatment:

$ python test.py
before
__init__
after
before
10
after
baz
baz
before
Traceback (most recent call last):
  File "test.py", line 44, in <module>
    obj.bar()
  File "test.py", line 7, in wrapper
    ret = func(*args, **kwargs)
TypeError: bar() takes exactly 1 argument (2 given)

Is there a way to handle this problem nicely? I inspected @classmethod-decorated methods, but I don't see anything that differentiates them from other "types" of methods.
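To see where the extra argument comes from, here is a minimal Python 3 reproduction (the class `C` and the helper `passthrough` are hypothetical; the question's code is Python 2). `inspect.getmembers` returns `bar` already bound to the class. Once the wrapper is stored back on the class as a plain function, it becomes an instance method, so calling it through an instance passes the instance on top of the class that is already bound:

```python
import inspect


class C:
    def foo(self):
        return "foo"

    @staticmethod
    def baz():
        return "baz"

    @classmethod
    def bar(cls):
        return cls


def passthrough(func):
    # Same shape as the question's decorate_func, minus the prints.
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


members = dict(inspect.getmembers(C))
print(type(members["foo"]).__name__)  # function
print(type(members["baz"]).__name__)  # function
print(type(members["bar"]).__name__)  # method: already bound to C
print(members["bar"].__self__ is C)   # True

# Stored back as a plain function, the wrapper is now an instance method,
# so C().bar() ends up calling bar(C, <instance>): one argument too many.
C.bar = passthrough(members["bar"])
try:
    C().bar()
except TypeError as exc:
    print("TypeError:", exc)
```

Calling `C.bar()` on the class still works, because no instance gets prepended. This matches the traceback above, where only the call through `obj` was reached.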

Update

Here is the complete solution, for the record (using descriptors to handle @staticmethods and @classmethods nicely, and aix's trick to tell @classmethods apart from regular methods):

import inspect


class DecoratedMethod(object):

    def __init__(self, func):
        self.func = func

    def __get__(self, obj, cls=None):
        def wrapper(*args, **kwargs):
            print "before"
            ret = self.func(obj, *args, **kwargs)
            print "after"
            return ret
        for attr in "__module__", "__name__", "__doc__":
            setattr(wrapper, attr, getattr(self.func, attr))
        return wrapper


class DecoratedClassMethod(object):

    def __init__(self, func):
        self.func = func

    def __get__(self, obj, cls=None):
        def wrapper(*args, **kwargs):
            print "before"
            ret = self.func(*args, **kwargs)
            print "after"
            return ret
        for attr in "__module__", "__name__", "__doc__":
            setattr(wrapper, attr, getattr(self.func, attr))
        return wrapper


def decorate_class(cls):
    for name, meth in inspect.getmembers(cls):
        if inspect.ismethod(meth):
            if inspect.isclass(meth.im_self):
                # meth is a classmethod
                setattr(cls, name, DecoratedClassMethod(meth))
            else:
                # meth is a regular method
                setattr(cls, name, DecoratedMethod(meth))
        elif inspect.isfunction(meth):
            # meth is a staticmethod
            setattr(cls, name, DecoratedClassMethod(meth))
    return cls


@decorate_class
class MyClass(object):

    def __init__(self):
        self.a = 10
        print "__init__"

    def foo(self):
        print self.a

    @staticmethod
    def baz():
        print "baz"

    @classmethod
    def bar(cls):
        print "bar"


obj = MyClass()
obj.foo()
obj.baz()
MyClass.baz()
obj.bar()
MyClass.bar()
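For Python 3, one possibly simpler route (a different technique from the descriptor solution above) is to read the raw entries in the class `__dict__`. There the `staticmethod` and `classmethod` objects are still visible, because attribute lookup has not yet unwrapped them, so each one can be rewrapped in its own type. Below is a minimal sketch of Python 3 variants of `decorate_func`/`decorate_class`. It assumes a hypothetical `calls` list in place of the prints so the effect is easy to check:

```python
import functools
import inspect

calls = []  # hypothetical call log standing in for the "before"/"after" prints


def decorate_func(func):
    @functools.wraps(func)  # copies __module__, __name__, __doc__, ...
    def wrapper(*args, **kwargs):
        calls.append("before")
        ret = func(*args, **kwargs)
        calls.append("after")
        return ret
    return wrapper


def decorate_class(cls):
    # vars(cls) holds the raw staticmethod/classmethod objects, which
    # getattr() and inspect.getmembers() would already have unwrapped.
    for name, attr in list(vars(cls).items()):
        if isinstance(attr, staticmethod):
            setattr(cls, name, staticmethod(decorate_func(attr.__func__)))
        elif isinstance(attr, classmethod):
            setattr(cls, name, classmethod(decorate_func(attr.__func__)))
        elif inspect.isfunction(attr):
            # A regular method is just a plain function stored on the class.
            setattr(cls, name, decorate_func(attr))
    return cls


@decorate_class
class MyClass:
    def __init__(self):
        self.a = 10

    def foo(self):
        return self.a

    @staticmethod
    def baz():
        return "baz"

    @classmethod
    def bar(cls):
        return cls.__name__
```

With this approach, the decorated attributes stay ordinary functions, staticmethods and classmethods instead of custom descriptor objects. The trade-off is that `vars(cls)` only covers methods defined on the class itself, not inherited ones.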

Upvotes: 11

Views: 3152

Answers (3)

nicky

Reputation: 1

The answers above do not apply directly to Python 3. Based on the other great answers, I came up with the following solution:

import inspect
import networkx as nx


def override_methods(cls):
    for name, meth in inspect.getmembers(cls):
        if name in cls.methods_to_override:
            setattr(cls, name, cls.DecorateMethod(meth))
    return cls


@override_methods
class DiGraph(nx.DiGraph):

    methods_to_override = ("add_node", "remove_edge", "add_edge")

    class DecorateMethod:

        def __init__(self, func):
            self.func = func

        def __get__(self, obj, cls=None):
            def wrapper(*args, **kwargs):
                ret = self.func(obj, *args, **kwargs)
                obj._dirty = True  # This is the attribute I want to update
                return ret
            return wrapper

    def __init__(self):
        super().__init__()
        self._dirty = True

Now, any time a method listed in the tuple methods_to_override is called, the dirty flag is set. Of course, anything else can be put there too. The DecorateMethod class does not have to be nested inside the class whose methods are being overridden; however, since DecorateMethod relies on attributes specific to that class, I prefer to make it a class attribute.
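In Python 3, `getattr(cls, name)` returns a regular method as a plain function. For this use case, a plain wrapping function can therefore stand in for the descriptor class: the wrapper receives `self` as its first argument instead of capturing `obj` in `__get__`. Below is a minimal sketch of the same dirty-flag idea. It assumes a hypothetical `BaseGraph` stand-in for `nx.DiGraph`, so it runs without networkx, and a hypothetical `mark_dirty` helper:

```python
import functools


class BaseGraph:
    """Hypothetical stand-in for nx.DiGraph so the sketch runs without networkx."""

    def __init__(self):
        self.node_set = set()
        self.edge_set = set()

    def add_node(self, n):
        self.node_set.add(n)

    def add_edge(self, u, v):
        self.node_set.update((u, v))
        self.edge_set.add((u, v))

    def remove_edge(self, u, v):
        self.edge_set.discard((u, v))

    def has_edge(self, u, v):
        return (u, v) in self.edge_set


def mark_dirty(func):
    """Hypothetical helper: run the original method, then set the dirty flag."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        ret = func(self, *args, **kwargs)
        self._dirty = True
        return ret
    return wrapper


def override_methods(cls):
    for name in cls.methods_to_override:
        # In Python 3 this is a plain function (inherited from BaseGraph here).
        setattr(cls, name, mark_dirty(getattr(cls, name)))
    return cls


@override_methods
class DiGraph(BaseGraph):
    methods_to_override = ("add_node", "remove_edge", "add_edge")

    def __init__(self):
        super().__init__()
        self._dirty = True


g = DiGraph()
g._dirty = False
g.has_edge(1, 2)  # not in methods_to_override: flag untouched
g.add_edge(1, 2)  # wrapped: flag set again
print(g._dirty)   # True
```

Because `setattr` only touches `DiGraph`, the base class keeps its undecorated methods.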

Upvotes: 0

tbm

Reputation: 1211

(Too long for a comment)

I took the liberty of extending your solution so that you can specify which methods get decorated:

import inspect


# DecoratedMethod and DecoratedClassMethod are the descriptors from the question's update
def class_decorator(*method_names):

    def wrapper(cls):

        for name, meth in inspect.getmembers(cls):
            # With no method names given, decorate every method
            if not method_names or name in method_names:
                if inspect.ismethod(meth):
                    if inspect.isclass(meth.im_self):
                        # meth is a classmethod
                        setattr(cls, name, DecoratedClassMethod(meth))
                    else:
                        # meth is a regular method
                        setattr(cls, name, DecoratedMethod(meth))
                elif inspect.isfunction(meth):
                    # meth is a staticmethod
                    setattr(cls, name, DecoratedClassMethod(meth))

        return cls

    return wrapper

Usage:

@class_decorator('some_method')
class Foo(object):

    def some_method(self):
        print 'I am decorated'

    def another_method(self):
        print 'I am NOT decorated'
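Python 3 has no `im_self`, and `inspect.getmembers` returns regular methods as plain functions. One hedged way to port the name filter is therefore to walk `vars(cls)` and rewrap `staticmethod`/`classmethod` objects in their own type. In the sketch below, `record` and `decorated_calls` are hypothetical stand-ins for `VerifyTokenMethod`:

```python
import functools
import inspect

decorated_calls = []  # hypothetical log of which decorated methods ran


def record(func):
    """Hypothetical stand-in for VerifyTokenMethod: log the call, then run it."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        decorated_calls.append(func.__name__)
        return func(*args, **kwargs)
    return wrapper


def class_decorator(*method_names):
    def wrapper(cls):
        for name, attr in list(vars(cls).items()):
            # No names given means "decorate every method".
            if method_names and name not in method_names:
                continue
            if isinstance(attr, (staticmethod, classmethod)):
                # Decorate the underlying function, then rewrap it in the same type.
                setattr(cls, name, type(attr)(record(attr.__func__)))
            elif inspect.isfunction(attr):
                setattr(cls, name, record(attr))
        return cls
    return wrapper


@class_decorator('some_method')
class Foo:
    def some_method(self):
        return 'I am decorated'

    def another_method(self):
        return 'I am NOT decorated'
```

Unlike `inspect.getmembers`, `vars(cls)` ignores inherited methods. This is arguably the safer default when the decorator is meant to touch only what the class itself defines.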

Upvotes: 1

NPE

Reputation: 500437

inspect.isclass(meth.im_self) should tell you whether meth is a class method:

import inspect


def decorate_class(cls):
    for name, meth in inspect.getmembers(cls, inspect.ismethod):
        if inspect.isclass(meth.im_self):
            print '%s is a class method' % name
            # TODO: wrap meth as a class method
        else:
            # TODO: wrap meth as a regular method
            pass
    return cls
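`im_self` exists only in Python 2; Python 3 spells it `__self__`. In Python 3, `getattr` and `inspect.getmembers` also return both regular methods and `staticmethod`s as plain functions, so the `__self__` check alone cannot tell those two apart. One option (a different technique from this answer's) is `inspect.getattr_static`, which returns the raw class attribute without running the descriptor protocol. Below is a minimal sketch; `method_kind` is a hypothetical helper:

```python
import inspect


def method_kind(cls, name):
    """Classify cls.<name> as 'classmethod', 'staticmethod', 'method' or None."""
    raw = inspect.getattr_static(cls, name)  # raw attribute, no descriptor unwrapping
    if isinstance(raw, classmethod):
        return "classmethod"
    if isinstance(raw, staticmethod):
        return "staticmethod"
    if inspect.isfunction(raw):
        return "method"
    return None  # e.g. slot wrappers inherited from object, or plain data


class MyClass:
    def foo(self):
        pass

    @staticmethod
    def baz():
        pass

    @classmethod
    def bar(cls):
        pass


for name in ("foo", "baz", "bar"):
    print(name, method_kind(MyClass, name))
```

Like `getattr`, `inspect.getattr_static` follows the MRO, so inherited methods are classified too.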

Upvotes: 11
