Lily Mara

Reputation: 4138

Decorator that gets self

I'm trying to write a decorator that can be applied to instance methods and plain functions alike. I have reduced my code to a minimal example that demonstrates my point:

def call(fn):
    def _impl(*args, **kwargs):
        return fn(*args, **kwargs)

    fn.call = _impl

    return fn

class Foo(object):
    @call
    def bar(self):
        pass

Foo().bar.call()

This gives the beautiful error:

Traceback (most recent call last):
  File "/tmp/511749370/main.py", line 14, in <module>
    Foo().bar.call()
  File "/tmp/511749370/main.py", line 3, in _impl
    return fn(*args, **kwargs)
TypeError: bar() missing 1 required positional argument: 'self'
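The traceback follows from how bound methods handle attributes: a bound method forwards any attribute it doesn't define itself to the underlying function (`__func__`). So `Foo().bar.call` is the very same `_impl` that was attached to the plain function, and it never receives the instance. A small check, reusing the decorator from the question with `bar` changed to return a value:

```python
def call(fn):
    def _impl(*args, **kwargs):
        return fn(*args, **kwargs)

    fn.call = _impl
    return fn


class Foo(object):
    @call
    def bar(self):
        return 'bar'


f = Foo()
method = f.bar  # a bound method object

# attribute lookup on a bound method falls through to its __func__
print(method.call is method.__func__.call)  # True
# same _impl whether reached via the instance or the class: no instance attached
print(method.call is Foo.bar.call)          # True
# works only when self is passed explicitly
print(Foo.bar.call(f))                      # prints: bar
```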

Is it possible to do something like this without resorting to

Foo.bar.call(Foo())

Or is that my only option?

Upvotes: 0

Views: 187

Answers (1)

Aran-Fey

Reputation: 43246

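You have to implement your decorator as a class that implements the descriptor protocol. The descriptor method __get__ is what creates bound methods: when you access Foo().bar, Python calls __get__ on the object stored under "bar" in Foo's class dictionary, passing it the instance and the class. By implementing __get__ yourself, you get access to the instance (self) and can return a bound copy of the decorator, including its call method.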
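A quick illustration of that mechanism with an ordinary method (no decorator involved): ordinary functions are themselves descriptors, and calling their `__get__` by hand produces the same bound method that attribute access does.

```python
class Foo(object):
    def bar(self):
        return self


f = Foo()
raw = Foo.__dict__['bar']      # the plain function stored on the class
bound = raw.__get__(f, Foo)    # what `f.bar` does behind the scenes

print(type(bound).__name__)    # method
print(bound.__self__ is f)     # True
print(bound() is f)            # True: the bound method supplied `self`
print(f.bar == bound)          # True: equivalent to normal attribute access
# accessed on the class (obj is None), a function returns itself
print(raw.__get__(None, Foo) is raw)  # True
```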

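The implementation below overrides __get__ in exactly this way. When the decorator is accessed through an instance, __get__ makes a copy of the decorator and stores that instance (here, the Foo object) in the copy's __self__ attribute. The decorator has a __call__ method that calls the decorated function, prepending __self__ as the first argument if it is set, and a call method that does the same thing.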

import inspect
import functools
from copy import copy

class call:
    def __init__(self, func):
        self.func = func
        self.__self__ = None  # "__self__" is also used by bound methods

    def __call__(self, *args, **kwargs):
        # if bound to an object, pass it as the first argument
        if self.__self__ is not None:
            args = (self.__self__,) + args

        return self.func(*args, **kwargs)

    def call(self, *args, **kwargs):
        # same as calling the decorator directly; pass the result through
        return self(*args, **kwargs)

    def __get__(self, obj, cls):
        if obj is None:
            return self

        # create a bound copy of the decorator
        bound = copy(self)
        bound.__self__ = obj

        # update __doc__ and similar attributes
        functools.wraps(bound.func)(bound)

        # the bound copy no longer takes `self`, so drop the first parameter
        sig = inspect.signature(bound.func)
        params = list(sig.parameters.values())[1:]
        bound.__signature__ = sig.replace(parameters=params)

        # add the bound instance to the object's dict so that
        # __get__ won't be called a 2nd time
        setattr(obj, self.func.__name__, bound)

        return bound
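The final `setattr` relies on a detail of attribute lookup: because `call` defines `__get__` but no `__set__`, it is a non-data descriptor, and an entry in the instance's `__dict__` takes precedence over it. The stripped-down descriptor below (`CountingDescriptor` and its counter are hypothetical, purely for illustration) shows that once the result is cached on the instance, `__get__` runs only on the first access. `functools.cached_property` uses the same mechanism. One caveat: a class that defines `__slots__` without a `__dict__` has nowhere to store the cached copy, so that `setattr` would raise `AttributeError`.

```python
class CountingDescriptor:
    """Hypothetical non-data descriptor: defines __get__ only, no __set__."""

    def __init__(self, func):
        self.func = func
        self.calls = 0  # how many times __get__ has actually run

    def __get__(self, obj, cls):
        if obj is None:
            return self
        self.calls += 1
        value = self.func(obj)
        # cache on the instance; the instance __dict__ now shadows the descriptor
        setattr(obj, self.func.__name__, value)
        return value


class Foo:
    @CountingDescriptor
    def answer(self):
        return 42


f = Foo()
print(f.answer, f.answer)              # 42 42
print(Foo.__dict__['answer'].calls)    # 1: the second lookup came from f.__dict__
print('answer' in vars(f))             # True
```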

Test:

class Foo(object):
    @call
    def bar(self):
        print('bar')

@call
def foo():
    print('foo')

Foo().bar.call() # output: bar
foo() # output: foo
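A possible variation, not part of the answer above: instead of copying the decorator and prepending `__self__` by hand, `__get__` can let `types.MethodType` build a real bound method and wrap that in a fresh decorator instance. This is a sketch under the assumption that you don't need the per-instance caching; it creates a new wrapper on every attribute access. The extra `greeting` parameter is only there to show that arguments pass through.

```python
import functools
import types


class call:
    def __init__(self, func):
        self.func = func  # a plain function, or a bound method made in __get__
        functools.update_wrapper(self, func)  # copy __name__, __doc__, ...

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)

    def call(self, *args, **kwargs):
        return self(*args, **kwargs)

    def __get__(self, obj, cls):
        if obj is None:
            return self  # accessed on the class: behave like a plain function
        # bind `obj` as the first argument, then wrap the bound method
        return call(types.MethodType(self.func, obj))


class Foo(object):
    @call
    def bar(self, greeting='hi'):
        return f'{greeting} from {type(self).__name__}'


@call
def foo():
    return 'foo'


print(Foo().bar.call())          # hi from Foo
print(Foo().bar.call('hello'))   # hello from Foo
print(Foo.bar.call(Foo()))       # hi from Foo
print(foo(), foo.call())         # foo foo
```

Because the bound case delegates to a genuine bound method, `__call__` needs no special handling for `self`, and `Foo.bar.call(Foo())` keeps working for callers who pass the instance explicitly.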

Upvotes: 2
