Reputation: 9492
I'm trying to write a class decorator that applies a decorator to all the class' methods:
import inspect
def decorate_func(func):
    """Return a wrapper that prints "before" and "after" around each call to *func*.

    The wrapper passes all positional and keyword arguments through and
    returns func's result. It copies func's __module__, __name__ and
    __doc__ so that it looks like the original when introspected.
    """
    def wrapper(*args, **kwargs):
        print("before")
        result = func(*args, **kwargs)
        print("after")
        return result

    wrapper.__module__ = func.__module__
    wrapper.__name__ = func.__name__
    wrapper.__doc__ = func.__doc__
    return wrapper
def decorate_class(cls):
    """Wrap every member of *cls* accepted by inspect.ismethod with decorate_func.

    Returns *cls* itself, modified in place.

    NOTE: a classmethod is fetched already bound to *cls*. The wrapper is
    stored back on the class as a plain function, so it is bound again on
    attribute access. The already-bound classmethod therefore receives one
    extra argument, which is the TypeError this question is about.
    """
    members = inspect.getmembers(cls, inspect.ismethod)
    for attr_name, bound in members:
        wrapped = decorate_func(bound)
        setattr(cls, attr_name, wrapped)
    return cls
@decorate_class
class MyClass(object):
    """Demo class with a regular method, a staticmethod and a classmethod."""

    def __init__(self):
        self.a = 10
        print("__init__")

    def foo(self):
        """Print the instance attribute ``a``."""
        print(self.a)

    @staticmethod
    def baz():
        """Print "baz"; needs neither an instance nor the class."""
        print("baz")

    @classmethod
    def bar(cls):
        """Print "bar"; receives the class as ``cls``."""
        print("bar")
# Exercise each kind of method, through an instance and through the class.
obj = MyClass()
obj.foo()        # regular method
obj.baz()        # staticmethod via the instance
MyClass.baz()    # staticmethod via the class
obj.bar()        # classmethod via the instance: raises the TypeError shown below
MyClass.bar()    # classmethod via the class
It almost works, but @classmethod methods need special treatment:
$ python test.py
before
__init__
after
before
10
after
baz
baz
before
Traceback (most recent call last):
File "test.py", line 44, in <module>
obj.bar()
File "test.py", line 7, in wrapper
ret = func(*args, **kwargs)
TypeError: bar() takes exactly 1 argument (2 given)
Is there a way to handle this problem nicely? I inspected @classmethod-decorated methods, but I don't see anything that distinguishes them from other "types" of methods.
Update
Here is the complete solution for the record (using descriptors to handle @staticmethod and @classmethod methods nicely, and aix's trick to tell @classmethod methods apart from normal methods):
import inspect
class DecoratedMethod(object):
    """Descriptor that wraps a regular method and prints "before"/"after"
    around every call.

    ``func`` is the method as fetched from the class (an unbound method on
    Python 2, a plain function on Python 3). It expects the instance as its
    first argument.
    """

    def __init__(self, func):
        self.func = func

    def __get__(self, obj, cls=None):
        func = self.func
        # Accessed on an instance: pre-bind it as the first argument.
        # Accessed on the class (obj is None, e.g. MyClass.foo(inst)): bind
        # nothing and let the caller pass the instance explicitly, as with
        # an undecorated method. Binding None here would break that call.
        prefix = () if obj is None else (obj,)

        def wrapper(*args, **kwargs):
            print("before")
            ret = func(*(prefix + args), **kwargs)
            print("after")
            return ret

        # Make the wrapper look like the wrapped method when introspected.
        for attr in ("__module__", "__name__", "__doc__"):
            setattr(wrapper, attr, getattr(func, attr))
        return wrapper
class DecoratedClassMethod(object):
    """Descriptor that wraps a classmethod or staticmethod and prints
    "before"/"after" around every call.

    ``func`` is the member as fetched from the class. For a classmethod it
    is a method already bound to the decorated class. For a staticmethod it
    is the plain function.
    """

    def __init__(self, func):
        self.func = func

    def __get__(self, obj, cls=None):
        func = self.func
        bound_to = getattr(func, "__self__", None)
        if inspect.isclass(bound_to):
            # classmethod: rebind the underlying function to the class the
            # attribute was looked up on. Otherwise a subclass would receive
            # the originally decorated class as ``cls`` instead of itself.
            owner = cls
            if owner is None:
                owner = bound_to if obj is None else type(obj)
            func = func.__func__.__get__(owner, type(owner))

        def wrapper(*args, **kwargs):
            print("before")
            ret = func(*args, **kwargs)
            print("after")
            return ret

        # Make the wrapper look like the wrapped member when introspected.
        for attr in ("__module__", "__name__", "__doc__"):
            setattr(wrapper, attr, getattr(self.func, attr))
        return wrapper
def decorate_class(cls):
    """Wrap every method of *cls*, including inherited ones, with a
    "before"/"after" printing descriptor, and return *cls*.

    Regular methods get DecoratedMethod. Classmethods and staticmethods get
    DecoratedClassMethod. Each member's kind is read from the raw class
    ``__dict__`` entry found along the MRO, not from the fetched attribute.
    The fetched form differs between versions: Python 2 yields unbound
    methods with ``im_self``, while Python 3 yields plain functions
    indistinguishable from staticmethods.
    """
    mro = inspect.getmro(cls)
    for name, meth in inspect.getmembers(cls):
        raw = _raw_class_attribute(mro, name)
        if isinstance(raw, (classmethod, staticmethod)):
            setattr(cls, name, DecoratedClassMethod(meth))
        elif inspect.isfunction(raw):
            # A plain function in a class body is a regular method.
            setattr(cls, name, DecoratedMethod(meth))
        # Anything else (properties, slot wrappers, already-decorated
        # descriptors, plain data) is left untouched.
    return cls


def _raw_class_attribute(mro, name):
    """Return the ``__dict__`` entry for *name* from the first class in *mro*
    that defines it, or None if no class does."""
    for klass in mro:
        namespace = vars(klass)
        if name in namespace:
            return namespace[name]
    return None
@decorate_class
class MyClass(object):
    """Demo class with a regular method, a staticmethod and a classmethod."""

    def __init__(self):
        self.a = 10
        print("__init__")

    def foo(self):
        """Print the instance attribute ``a``."""
        print(self.a)

    @staticmethod
    def baz():
        """Print "baz"; needs neither an instance nor the class."""
        print("baz")

    @classmethod
    def bar(cls):
        """Print "bar"; receives the class as ``cls``."""
        print("bar")
# Exercise each kind of method, through an instance and through the class.
obj = MyClass()
obj.foo()        # regular method
obj.baz()        # staticmethod via the instance
MyClass.baz()    # staticmethod via the class
obj.bar()        # classmethod via the instance
MyClass.bar()    # classmethod via the class
Upvotes: 11
Views: 3152
Reputation: 1
The answers above do not apply directly to Python 3. Based on the other great answers, I came up with the following solution:
import inspect
import types
import networkx as nx
def override_methods(cls):
    """Wrap each method named in ``cls.methods_to_override`` with
    ``cls.DecorateMethod``, and return *cls*.

    Only the listed names are looked up, rather than materializing every
    member of the class with inspect.getmembers. Names the class does not
    have are skipped silently.
    """
    missing = object()
    # dict.fromkeys drops repeated names while keeping their order, so a
    # method listed twice is not wrapped twice.
    for name in dict.fromkeys(cls.methods_to_override):
        meth = getattr(cls, name, missing)
        if meth is not missing:
            setattr(cls, name, cls.DecorateMethod(meth))
    return cls
@override_methods
class DiGraph(nx.DiGraph):
    """networkx DiGraph that sets ``self._dirty = True`` after every call to
    one of the methods listed in ``methods_to_override``."""

    methods_to_override = ("add_node", "remove_edge", "add_edge")

    class DecorateMethod:
        """Descriptor that calls the wrapped method, then marks the graph
        instance dirty."""

        def __init__(self, func):
            self.func = func

        def __get__(self, obj, cls=None):
            import functools

            func = self.func

            @functools.wraps(func)  # keep the wrapped method's name and doc
            def wrapper(instance, *args, **kwargs):
                ret = func(instance, *args, **kwargs)
                instance._dirty = True  # This is the attribute I want to update
                return ret

            if obj is None:
                # Accessed on the class (DiGraph.add_edge(g, u, v)): the caller
                # passes the instance explicitly, so return the unbound wrapper.
                return wrapper
            return types.MethodType(wrapper, obj)

    def __init__(self, incoming_graph_data=None, **attr):
        # Forward networkx's constructor arguments instead of dropping them,
        # so DiGraph(existing_graph, name="g") still works.
        super().__init__(incoming_graph_data, **attr)
        self._dirty = True
Now, any time a method listed in the methods_to_override tuple is called, the dirty flag is set. Of course, anything else can be put there too. It is not necessary to define the DecorateMethod class inside the class whose methods need to be overridden. However, since DecorateMethod uses attributes specific to that class, I prefer to make it a class attribute.
Upvotes: 0
Reputation: 1211
(Too long for a comment)
I took the liberty of extending your solution so that you can specify which methods should be decorated:
def class_decorator(*method_names):
    """Build a class decorator that wraps methods with VerifyTokenMethod.

    VerifyTokenMethod is defined elsewhere (not shown here).

    :param method_names: names of the members to wrap. When none are given,
        every member that inspect reports as a method or a function is
        wrapped: regular methods, classmethods and staticmethods.
    :returns: a decorator that modifies the class in place and returns it.
    """
    def decorate(cls):
        for name, meth in inspect.getmembers(cls):
            if method_names and name not in method_names:
                continue
            # Regular methods, classmethods ("methods") and staticmethods
            # (plain functions) all receive the same wrapper, so there is no
            # need to tell them apart. The Python 2-only ``im_self`` check is
            # dropped, since on Python 3 it raises AttributeError.
            if inspect.ismethod(meth) or inspect.isfunction(meth):
                setattr(cls, name, VerifyTokenMethod(meth))
        return cls
    return decorate
Usage:
@class_decorator('some_method')
class Foo(object):
    """Example: only ``some_method`` is selected for wrapping."""

    def some_method(self):
        print('I am decorated')

    def another_method(self):
        print('I am NOT decorated')
Upvotes: 1
Reputation: 500437
inspect.isclass(meth.im_self) should tell you whether meth is a class method:
def decorate_class(cls):
    """Sketch: report which methods of *cls* are class methods, then return *cls*.

    A classmethod fetched from the class is bound to the class itself, so
    its ``im_self`` is a class. For other methods ``im_self`` is not a class.
    """
    for name, meth in inspect.getmembers(cls, inspect.ismethod):
        if not inspect.isclass(meth.im_self):
            continue
        print('%s is a class method' % name)
        # TODO: wrap the class method here; the remaining handling is elided.
    return cls
Upvotes: 11