I'm trying to write a decorator that can be applied to instance methods and plain functions alike. I have reduced my code to a minimal example that demonstrates the problem:
def call(fn):
    def _impl(*args, **kwargs):
        return fn(*args, **kwargs)
    fn.call = _impl
    return fn


class Foo(object):
    @call
    def bar(self):
        pass


Foo().bar.call()
This fails with the following error:
Traceback (most recent call last):
  File "/tmp/511749370/main.py", line 14, in <module>
    Foo().bar.call()
  File "/tmp/511749370/main.py", line 3, in _impl
    return fn(*args, **kwargs)
TypeError: bar() missing 1 required positional argument: 'self'
Is it possible to make this work without resorting to

Foo.bar.call(Foo())

or is that my only option?
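You have to implement your decorator as a class that implements the descriptor protocol. The descriptor method __get__ is what creates bound methods: a lookup such as Foo().bar calls Foo.__dict__['bar'].__get__(instance, Foo). In your version, fn.call is just an attribute stored on the function object, so Foo().bar.call is the same unbound _impl and never receives the instance. By overriding __get__, you get access to the instance (the object that becomes self) and can create a bound copy of the decorator whose call method knows about it.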
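To see why the question's version loses self, here is a minimal sketch using a plain function. The attribute name extra is a hypothetical stand-in for the question's fn.call; everything else is standard Python behavior.

```python
class Foo:
    def bar(self):
        return "bar called"


f = Foo()

# Functions are non-data descriptors: looking up f.bar calls the
# function's __get__, which produces a bound method.
raw = Foo.__dict__["bar"]        # the plain function object
bound = raw.__get__(f, Foo)      # what f.bar does behind the scenes
print(bound.__self__ is f)       # True: the bound method stores the instance
print(bound())                   # bar called

# An attribute stored on the function (hypothetical "extra", like fn.call
# in the question) is reachable through the bound method, but it is the
# very same object, so it is never given the instance.
raw.extra = lambda *args: args
print(f.bar.extra is raw.extra)  # True
print(f.bar.extra())             # () -> no instance was passed
```

The bound method only forwards unknown attribute lookups to the underlying function; it does not bind those attributes, which is why a descriptor of your own is needed.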
The following implementation does exactly that. The bound copy stores the Foo instance in its __self__ attribute. The decorator has a __call__ method, which calls the decorated function (passing __self__ as the first argument when it is set), and a call method that does the same thing.
import inspect
import functools
from copy import copy


class call:
    def __init__(self, func):
        self.func = func
        self.__self__ = None  # "__self__" is also used by bound methods

    def __call__(self, *args, **kwargs):
        # if bound to an object, pass it as the first argument
        if self.__self__ is not None:
            args = (self.__self__,) + args
        return self.func(*args, **kwargs)

    def call(self, *args, **kwargs):
        # return the result instead of discarding it
        return self(*args, **kwargs)

    def __get__(self, obj, cls):
        if obj is None:
            return self

        # create a bound copy of the decorator
        bound = copy(self)
        bound.__self__ = obj

        # update __doc__ and similar attributes
        functools.wraps(bound.func)(bound)
        # drop the first parameter ("self"), as a bound method's signature does
        sig = inspect.signature(bound.func)
        bound.__signature__ = sig.replace(parameters=list(sig.parameters.values())[1:])

        # add the bound instance to the object's dict so that
        # __get__ won't be called a 2nd time
        setattr(obj, self.func.__name__, bound)

        return bound
Test:
class Foo(object):
    @call
    def bar(self):
        print('bar')


@call
def foo():
    print('foo')


Foo().bar.call()  # output: bar
foo()             # output: foo
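The setattr at the end of __get__ works as a cache because call defines __get__ but no __set__, which makes it a non-data descriptor: an entry in the instance __dict__ takes precedence over it on later lookups. Below is a minimal, self-contained sketch of that rule, using a hypothetical CachingDescriptor that counts how often its __get__ runs (it uses __set_name__ to learn its attribute name, whereas the answer uses func.__name__).

```python
class CachingDescriptor:
    """Hypothetical non-data descriptor: defines __get__ but no __set__."""

    def __init__(self):
        self.get_calls = 0

    def __set_name__(self, owner, name):
        # Remember the attribute name this descriptor is assigned to
        self.name = name

    def __get__(self, obj, cls):
        if obj is None:
            return self
        self.get_calls += 1
        value = f"computed for {type(obj).__name__}"
        # No __set__ on this class, so setattr stores the value in the
        # instance __dict__, where it shadows the descriptor from now on.
        setattr(obj, self.name, value)
        return value


class Foo:
    attr = CachingDescriptor()


f = Foo()
first = f.attr    # runs __get__
second = f.attr   # found in f.__dict__; __get__ is skipped
print(Foo.__dict__["attr"].get_calls)  # 1
print(vars(f))                         # {'attr': 'computed for Foo'}
```

One consequence of caching in the instance __dict__: for a class that defines __slots__ without a __dict__, the setattr call in __get__ would raise AttributeError.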