MattSom

Reputation: 2377

Check which decorator was applied to a function

Following this question, I have an idea of how to check whether my function was decorated.

However, I need further information: namely, which decorators were actually applied to the function (or which were called when the function was called, if that suits better).

To be safe from the danger mentioned in this answer, I am using functools.wraps. This way I don't have to worry about the wrapper's name being redefined.

This is what I have so far:

from functools import wraps

def decorator_wraps(function):
    @wraps(function)
    def _wrapper(*a, **kw): ...
    return _wrapper

def is_decorated(func):
    return hasattr(func, '__wrapped__')

@decorator_wraps
def foo(x, y): ...

print(is_decorated(foo))  # True
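Because functools.wraps sets __wrapped__ on every wrapper it produces, the chain can be followed layer by layer. The sketch below (count_layers is a hypothetical helper name) counts how many wraps-based layers sit on top of the original function. Note that it only tells you how many layers there are, not which decorator produced each one, which is exactly the missing piece.

```python
from functools import wraps


def decorator_wraps(function):
    @wraps(function)
    def _wrapper(*a, **kw):
        return function(*a, **kw)
    return _wrapper


def count_layers(func):
    """Count wrapper layers by following the __wrapped__ chain set by functools.wraps."""
    layers = 0
    while hasattr(func, "__wrapped__"):
        func = func.__wrapped__
        layers += 1
    return layers


@decorator_wraps
@decorator_wraps
def foo(x, y):
    return x + y


print(count_layers(foo))  # 2
```

The standard library's inspect.unwrap() follows the same chain, but returns the innermost function rather than a count.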

But what I need:

from functools import wraps

def decorator_wraps_1(function):
    @wraps(function)
    def _wrapper(*a, **kw): ...
    return _wrapper

def decorator_wraps_2(function):
    @wraps(function)
    def _wrapper(*a, **kw): ...
    return _wrapper

def decorators(func):
    # returns list of decorators on `func`
    ...

# OR

def is_decorated_by(func, decorator):
    # returns True if `func` is decorated by `decorator`
    ...

@decorator_wraps_1
@decorator_wraps_2
def foo(x, y): ...

print(decorators(foo))  # [decorator_wraps_1, decorator_wraps_2]
print(is_decorated_by(foo, decorator_wraps_1))  # True

TLDR

I want to determine whether my function was decorated, and I also need the names of the decorator functions.

Any idea how to achieve this?

Upvotes: 5

Views: 1031

Answers (2)

T1nk-R

Reputation: 23

The following solutions work with Python 3.9.13. The second option does not require you to have control over the decorators.

If you have control over the decorators and follow a naming convention in which each wrapper function is named after its decorator (e.g. the wrapper of decorator foo is called foo_wrapper), you can use getmembers() to list the functions in a module and check each one's __name__ attribute, which holds the wrapper's name:

from inspect import getmembers, isfunction
from types import ModuleType
from typing import List

import tasks  # the module containing the decorated functions (see below)

def test():
    return get_decorated_with(tasks, 'pipeline_decorator')

def get_decorated_with(module: ModuleType, decorator_name: str) -> List[str]:
    pipelines = []

    fns = getmembers(module, isfunction)
    for (fname, fobj) in fns:
        # The wrapper keeps its own name, e.g. 'pipeline_decorator_wrapper'
        if fobj.__name__ == f'{decorator_name}_wrapper':
            print(f'Yay, {fname} is decorated with {decorator_name}')
            pipelines.append(fname)

    return pipelines

Say you have a module named tasks, with a decorator called pipeline_decorator (whose wrapper function is named pipeline_decorator_wrapper) applied to a function called decorated_function():

def pipeline_decorator(func):
    def pipeline_decorator_wrapper(*args, **kwargs):
        ...
    return pipeline_decorator_wrapper

@pipeline_decorator
def decorated_function(stuff):
    ...

then calling test() will print

Yay, decorated_function is decorated with pipeline_decorator

and get_decorated_with() will return a list with a single item, 'decorated_function'.
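To try this without a separate tasks.py file, the sketch below builds a stand-in tasks module in memory with types.ModuleType (an assumption made only so the example is self-contained) and runs a condensed version of the __name__-based check, without the print, against it.

```python
import types
from inspect import getmembers, isfunction
from typing import List


def get_decorated_with(module: types.ModuleType, decorator_name: str) -> List[str]:
    """Names of functions whose wrapper follows the '<decorator>_wrapper' convention."""
    return [
        fname
        for fname, fobj in getmembers(module, isfunction)
        if fobj.__name__ == f"{decorator_name}_wrapper"
    ]


def pipeline_decorator(func):
    def pipeline_decorator_wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return pipeline_decorator_wrapper


@pipeline_decorator
def decorated_function(stuff):
    return stuff


def plain_function(stuff):
    return stuff


# Stand-in for the `tasks` module: attach the functions to an in-memory module.
tasks = types.ModuleType("tasks")
tasks.pipeline_decorator = pipeline_decorator
tasks.decorated_function = decorated_function
tasks.plain_function = plain_function

print(get_decorated_with(tasks, "pipeline_decorator"))  # ['decorated_function']
```

The decorator itself and the undecorated function are skipped because their __name__ does not end in _wrapper.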

Improved version leaving decorator definitions intact

Instead of fobj.__name__, you can use fobj.__qualname__, which contains the names of both the decorator and the wrapper, e.g. 'pipeline_decorator.<locals>.wrapper'. This way you don't even need control over the decorators (provided they don't use functools.wraps, which copies __qualname__ from the wrapped function onto the wrapper): you can check for them by looking for decorator_name in fobj.__qualname__:

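from inspect import getmembers, isfunction
from types import ModuleType
from typing import List

import tasks  # the module containing the decorated functions

def test():
    return get_decorated_with(tasks, 'pipeline_decorator')

def get_decorated_with(module: ModuleType, decorator_name: str) -> List[str]:
    pipelines = []

    fns = getmembers(module, isfunction)
    for (fname, fobj) in fns:
        # Match '<decorator_name>.<locals>.' rather than the bare name, so the
        # decorator function itself (qualname 'pipeline_decorator') is not reported
        if f'{decorator_name}.<locals>.' in fobj.__qualname__:
            print(f'Yay, {fname} is decorated with {decorator_name}')
            pipelines.append(fname)

    return pipelines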
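The following sketch prints __qualname__ for two otherwise identical decorators, one plain and one using functools.wraps, to show where the decorator's name does and does not survive. The names wraps_decorator, plain_wrapped and wraps_wrapped are illustrative; the printed values assume the code runs at module level.

```python
import functools


def pipeline_decorator(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


def wraps_decorator(func):
    @functools.wraps(func)  # copies __name__, __qualname__, __doc__, ... from func
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


@pipeline_decorator
def plain_wrapped():
    ...


@wraps_decorator
def wraps_wrapped():
    ...


# Without functools.wraps, the wrapper keeps its own qualified name.
print(plain_wrapped.__qualname__)  # pipeline_decorator.<locals>.wrapper
# With functools.wraps, __qualname__ is copied from the decorated function.
print(wraps_wrapped.__qualname__)  # wraps_wrapped
```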

Upvotes: 2

Daniil Fajnberg

Reputation: 18458

TL;DR

Roll your own @wraps.

import functools

def update_wrapper(wrapper, wrapped, decorator, **kwargs):
    wrapper = functools.update_wrapper(wrapper, wrapped, **kwargs)
    if decorator is not None:
        __decorators__ = getattr(wrapper, "__decorators__", [])
        setattr(wrapper, "__decorators__", __decorators__ + [decorator])
    return wrapper

def wraps(wrapped, decorator, **kwargs):
    return functools.partial(
        update_wrapper, wrapped=wrapped, decorator=decorator, **kwargs
    )

def get_decorators(func):
    return getattr(func, "__decorators__", [])

def is_decorated_by(func, decorator):
    return decorator in get_decorators(func)

Usage:

def test_decorator_1(function):
    @wraps(function, test_decorator_1)
    def wrapper(*args, **kwargs):
        return function(*args, **kwargs)
    return wrapper

def test_decorator_2(function):
    @wraps(function, test_decorator_2)
    def wrapper(*args, **kwargs):
        return function(*args, **kwargs)
    return wrapper

@test_decorator_1
@test_decorator_2
def foo(x: str, y: int) -> None:
    print(x, y)

assert get_decorators(foo) == [test_decorator_2, test_decorator_1]
assert is_decorated_by(foo, test_decorator_1)

Custom @wraps

Concept

There is no built-in way to do this, as far as I know. All it takes to create a (functional) decorator is to define a function that takes another function as an argument and returns a function. Decoration does not magically imprint any information about that "outer" function onto the returned function.

However, we can lean on the functools.wraps approach and roll our own variation of it, one that takes a reference not just to the wrapped function but also to the outer decorator.

The same way functools.update_wrapper defines the additional __wrapped__ attribute on the wrapper it outputs, we can define our own custom __decorators__ attribute, which is simply a list of all the decorators in the order of application (the reverse of the order in which they are written).

Code

The proper type annotations are a bit tricky, but here is a full working example:

import functools
from collections.abc import Callable
from typing import Any, ParamSpec, TypeAlias, TypeVar


P = ParamSpec("P")
T = TypeVar("T")
AnyFunc: TypeAlias = Callable[..., Any]


def update_wrapper(
    wrapper: Callable[P, T],
    wrapped: AnyFunc,
    decorator: AnyFunc | None = None,
    assigned: tuple[str, ...] = functools.WRAPPER_ASSIGNMENTS,
    updated: tuple[str, ...] = functools.WRAPPER_UPDATES,
) -> Callable[P, T]:
    """
    Same as `functools.update_wrapper`, but can also add `__decorators__`.

    If a `decorator` argument is provided, it is appended to the
    `__decorators__` attribute of `wrapper` before returning it.
    If `wrapper` has no `__decorators__` attribute, a list with just
    `decorator` in it is created and set as that attribute on `wrapper`.
    """
    wrapper = functools.update_wrapper(
        wrapper,
        wrapped,
        assigned=assigned,
        updated=updated,
    )
    if decorator is not None:
        __decorators__ = getattr(wrapper, "__decorators__", [])
        setattr(wrapper, "__decorators__", __decorators__ + [decorator])
    return wrapper


def wraps(
    wrapped: AnyFunc,
    decorator: AnyFunc | None,
    assigned: tuple[str, ...] = functools.WRAPPER_ASSIGNMENTS,
    updated: tuple[str, ...] = functools.WRAPPER_UPDATES
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Same as `functools.wraps`, but uses custom `update_wrapper` inside."""
    return functools.partial(
        update_wrapper,  # type: ignore[arg-type]
        wrapped=wrapped,
        decorator=decorator,
        assigned=assigned,
        updated=updated,
    )


def get_decorators(func: AnyFunc) -> list[AnyFunc]:
    return getattr(func, "__decorators__", [])


def is_decorated_by(func: AnyFunc, decorator: AnyFunc) -> bool:
    return decorator in get_decorators(func)


def test() -> None:
    def test_decorator_1(function: Callable[P, T]) -> Callable[P, T]:
        @wraps(function, test_decorator_1)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            print(f"Called wrapper from {test_decorator_1.__name__}")
            return function(*args, **kwargs)
        return wrapper

    def test_decorator_2(function: Callable[P, T]) -> Callable[P, T]:
        @wraps(function, test_decorator_2)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            print(f"Called wrapper from {test_decorator_2.__name__}")
            return function(*args, **kwargs)
        return wrapper

    @test_decorator_1
    @test_decorator_2
    def foo(x: str, y: int) -> None:
        print(x, y)

    assert get_decorators(foo) == [test_decorator_2, test_decorator_1]
    assert is_decorated_by(foo, test_decorator_1)
    assert hasattr(foo, "__wrapped__")
    foo("a", 1)


if __name__ == '__main__':
    test()

The output is of course:

Called wrapper from test_decorator_1
Called wrapper from test_decorator_2
a 1

Details and caveats

With this approach, none of the original functionality of functools.wraps should be lost. Just like the original, this @wraps decorator relies on you passing the correct arguments for the result to make sense: if you pass a nonsense argument to @wraps, it will add nonsense information to your wrapper.
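To check the claim that nothing from functools.wraps is lost, the sketch below re-declares the compact update_wrapper/wraps pair from the TL;DR and confirms that __name__, __doc__, __wrapped__ and the inspect.signature of the decorated function survive alongside the new __decorators__ attribute. The decorator name logged is illustrative.

```python
import functools
import inspect


def update_wrapper(wrapper, wrapped, decorator, **kwargs):
    # Standard metadata copy, then append `decorator` to a fresh __decorators__ list.
    wrapper = functools.update_wrapper(wrapper, wrapped, **kwargs)
    if decorator is not None:
        __decorators__ = getattr(wrapper, "__decorators__", [])
        setattr(wrapper, "__decorators__", __decorators__ + [decorator])
    return wrapper


def wraps(wrapped, decorator, **kwargs):
    return functools.partial(
        update_wrapper, wrapped=wrapped, decorator=decorator, **kwargs
    )


def logged(function):
    @wraps(function, logged)
    def wrapper(*args, **kwargs):
        return function(*args, **kwargs)
    return wrapper


@logged
def foo(x: str, y: int) -> None:
    """Print x and y."""
    print(x, y)


print(foo.__name__, foo.__doc__)        # foo Print x and y.
print(foo.__wrapped__ is not foo)       # True
print(inspect.signature(foo))           # (x: str, y: int) -> None
print(foo.__decorators__ == [logged])   # True
```

inspect.signature follows __wrapped__, which is why the wrapper reports the original parameters rather than (*args, **kwargs).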

The difference is you now have to provide two function references instead of one, namely the function being wrapped (as before) and the outer decorator (or None if you want to suppress that information for some reason). So you would typically use it as @wraps(function, decorator).

If you don't like that the decorator argument is mandatory, you could have it default to None. But I thought it was better this way, since the whole point is to have a consistent way of tracking who decorated whom, so omitting the decorator reference should be a conscious choice.

Note that I chose to store __decorators__ in the order of application: although stacked decorators are written in reverse order, they are applied from the bottom up. So in this example foo is first decorated with @test_decorator_2, and the resulting wrapper is then decorated with @test_decorator_1. It made more sense to me for the list to reflect that order.
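A small sketch to observe this order directly (make_decorator and applied are hypothetical names): each decorator appends its label to a list at the moment it is applied. The expressions make_decorator("outer") and make_decorator("inner") are themselves evaluated top to bottom, but the resulting decorators are called bottom-up, so the result matches outer(inner(foo)).

```python
applied = []


def make_decorator(name):
    """Build a decorator that logs `name` to `applied` at decoration time."""
    def decorator(func):
        applied.append(name)
        return func
    return decorator


@make_decorator("outer")
@make_decorator("inner")
def foo():
    ...


print(applied)  # ['inner', 'outer']
```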

Static type checks

With the given type annotations, mypy --strict is happy as well, and any IDE should still provide auto-suggestions as expected. The only thing that threw me off was that mypy complained about passing update_wrapper as an argument to functools.partial. I could not figure out why, so I added a # type: ignore there.

NOTE: If you are on Python < 3.10, you'll probably need to adjust the imports, for example by taking ParamSpec and TypeAlias from typing_extensions instead. Also, instead of T | None, you'll need to use typing.Optional[T]. Or upgrade your Python version. 🙂
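A minimal sketch of the adjusted header, assuming the third-party typing_extensions package is installed when running on Python < 3.10 (on 3.10+ the standard typing module is used and nothing extra is needed). On 3.8, the built-in generics list[...] and tuple[...] are not subscriptable either, so typing.List and typing.Tuple are used here; DecoratorArg and Assignments are hypothetical alias names for illustration.

```python
import sys
from typing import Any, Callable, List, Optional, Tuple, TypeVar

if sys.version_info >= (3, 10):
    from typing import ParamSpec, TypeAlias
else:  # Python < 3.10: backports from the typing_extensions package
    from typing_extensions import ParamSpec, TypeAlias

P = ParamSpec("P")
T = TypeVar("T")
AnyFunc: TypeAlias = Callable[..., Any]

# `AnyFunc | None` becomes Optional[AnyFunc]; tuple[str, ...] becomes Tuple[str, ...].
DecoratorArg = Optional[AnyFunc]
Assignments = Tuple[str, ...]


def get_decorators(func: AnyFunc) -> List[AnyFunc]:
    return getattr(func, "__decorators__", [])
```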

Upvotes: 1
