W1zard

Reputation: 508

How to create a Python decorator that can wrap either coroutine or function?

I am trying to make a decorator to wrap either coroutines or functions.

The first thing I tried was simply duplicating the code in two wrappers:

import asyncio
import functools
import time

def duration(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_ts = time.time()
        result = func(*args, **kwargs)
        dur = time.time() - start_ts
        print('{} took {:.2} seconds'.format(func.__name__, dur))
        return result

    @functools.wraps(func)
    async def async_wrapper(*args, **kwargs):
        start_ts = time.time()
        result = await func(*args, **kwargs)
        dur = time.time() - start_ts
        print('{} took {:.2} seconds'.format(func.__name__, dur))
        return result

    if asyncio.iscoroutinefunction(func):
        return async_wrapper
    else:
        return wrapper

This works, but I want to avoid the duplicated code; this is not much better than writing two separate decorators.

Then I tried to write the decorator as a class:

import asyncio
import functools
import time

class SyncAsyncDuration:
    def __init__(self):
        self.start_ts = None

    def __call__(self, func):
        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            self.setup(func, args, kwargs)
            result = func(*args, **kwargs)
            self.teardown(func, args, kwargs)
            return result

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            self.setup(func, args, kwargs)
            result = await func(*args, **kwargs)
            self.teardown(func, args, kwargs)
            return result

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper

    def setup(self, func, args, kwargs):
        self.start_ts = time.time()

    def teardown(self, func, args, kwargs):
        dur = time.time() - self.start_ts
        print('{} took {:.2} seconds'.format(func.__name__, dur))

That works very well for me in some cases, but with this approach I can't wrap the call to func in a with or try statement. Is there any way to create such a decorator without duplicating code?

Upvotes: 36

Views: 19443

Answers (3)

valcapp

Reputation: 471

In agreement with Anatoly, this solution combines the previous answers and makes sure the original type of func is preserved (a sync function stays sync, an async function stays async):

import time
import asyncio
from contextlib import contextmanager
import functools

def decorate_sync_async(decorating_context, func):
    if asyncio.iscoroutinefunction(func):
        async def decorated(*args, **kwargs):
            with decorating_context():
                return (await func(*args, **kwargs))
    else:
        def decorated(*args, **kwargs):
            with decorating_context():
                return func(*args, **kwargs)

    return functools.wraps(func)(decorated)

@contextmanager
def wrapping_logic(func_name):
    start_ts = time.time()
    yield
    dur = time.time() - start_ts
    print('{} took {:.2} seconds'.format(func_name, dur))


def duration(func):
    timing_context = lambda: wrapping_logic(func.__name__)
    return decorate_sync_async(timing_context, func)

decorate_sync_async can now be reused with any wrapping logic (any context manager) to create a decorator that works for both sync and async functions.
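As a sketch of that reuse, here is the same helper driving a different, hypothetical piece of wrapping logic (log_calls, logged and the events list are illustrative names, not part of the answer). The helper is copied from above, except that it uses inspect.iscoroutinefunction, which behaves the same for async def functions and, as far as I know, is the spelling preferred in newer Python versions:

```python
import asyncio
import functools
import inspect
from contextlib import contextmanager

def decorate_sync_async(decorating_context, func):
    # same helper as above, keeping sync functions sync and async ones async
    if inspect.iscoroutinefunction(func):
        async def decorated(*args, **kwargs):
            with decorating_context():
                return await func(*args, **kwargs)
    else:
        def decorated(*args, **kwargs):
            with decorating_context():
                return func(*args, **kwargs)
    return functools.wraps(func)(decorated)

events = []

@contextmanager
def log_calls(func_name):
    # hypothetical wrapping logic: record entry and exit of each call
    events.append(f"enter {func_name}")
    try:
        yield
    finally:
        events.append(f"exit {func_name}")

def logged(func):
    return decorate_sync_async(lambda: log_calls(func.__name__), func)

@logged
def add(a, b):
    return a + b

@logged
async def async_add(a, b):
    await asyncio.sleep(0)
    return a + b

print(add(1, 2))                     # 3
print(asyncio.run(async_add(3, 4)))  # 7
print(events)
# ['enter add', 'exit add', 'enter async_add', 'exit async_add']
```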

To use it (and check it):

@duration
def sync_hello():
    print('sync_hello')

@duration
async def async_hello():
    await asyncio.sleep(0.1)
    print('async_hello')

async def main():
    print(f"is {sync_hello.__name__} async? "
        f"{asyncio.iscoroutinefunction(sync_hello)}") # False
    sync_hello()

    print(f"is {async_hello.__name__} async? "
        f"{asyncio.iscoroutinefunction(async_hello)}") # True
    await async_hello()


if __name__ == '__main__':
    asyncio.run(main())

Output:

is sync_hello async? False
sync_hello
sync_hello took 0.0 seconds
is async_hello async? True
async_hello
async_hello took 0.1 seconds
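One caveat applies to wrapping_logic as written here (and in the other answers): if func raises, the generator is resumed with that exception at the yield, so the lines after yield never run and nothing is printed. A minimal sketch of a variant that always reports the duration wraps the yield in try/finally; it keeps the same signature, so it should drop into duration unchanged. I also swapped time.time() for time.perf_counter(), which is meant for measuring short intervals, and the timings list is a hypothetical addition only there to make the result checkable:

```python
import time
from contextlib import contextmanager

timings = []  # hypothetical: collected (name, duration) pairs

@contextmanager
def wrapping_logic(func_name):
    start_ts = time.perf_counter()
    try:
        yield
    finally:
        # runs both on normal exit and when the wrapped call raises
        dur = time.perf_counter() - start_ts
        timings.append((func_name, dur))
        print('{} took {:.2} seconds'.format(func_name, dur))

try:
    with wrapping_logic('failing_call'):
        raise RuntimeError('boom')
except RuntimeError:
    pass  # the duration was still recorded and printed

print(timings[0][0])  # failing_call
```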

Upvotes: 6

Ben

Reputation: 311

For me, the accepted answer by @mikhail-gerasimov did not work with async FastAPI methods (though it did work with normal and coroutine functions outside of FastAPI). However, I found an example on GitHub that does work with FastAPI methods. Here it is, slightly adapted:

import asyncio
import functools
import time

def duration(func):

    async def helper(func, *args, **kwargs):
        if asyncio.iscoroutinefunction(func):
            print(f"this function is a coroutine: {func.__name__}")
            return await func(*args, **kwargs)
        else:
            print(f"not a coroutine: {func.__name__}")
            return func(*args, **kwargs)

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        start_ts = time.time()
        result = await helper(func, *args, **kwargs)
        dur = time.time() - start_ts
        print('{} took {:.2} seconds'.format(func.__name__, dur))

        return result

    return wrapper

Alternatively, if you want to keep the context manager, you can do this:

import asyncio
import functools
import time
from contextlib import contextmanager

def duration(func):
    """ decorator that can take either coroutine or normal function """
    @contextmanager
    def wrapping_logic():
        start_ts = time.time()
        yield
        dur = time.time() - start_ts
        print('{} took {:.2} seconds'.format(func.__name__, dur))

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        if not asyncio.iscoroutinefunction(func):
            with wrapping_logic():
                return func(*args, **kwargs)
        else:
            with wrapping_logic():
                return (await func(*args, **kwargs))
    return wrapper

The difference from the accepted answer is small: here the wrapper is always async, and it awaits func only if func is a coroutine function.
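A quick check of what the always-async wrapper means for plain functions (same code as the context-manager version, except that it uses inspect.iscoroutinefunction; sync_hello is just an example): the decorated function becomes a coroutine function, so even a sync function has to be awaited or run through an event loop after decoration.

```python
import asyncio
import functools
import inspect
import time
from contextlib import contextmanager

def duration(func):
    """ decorator that can take either coroutine or normal function """
    @contextmanager
    def wrapping_logic():
        start_ts = time.time()
        yield
        dur = time.time() - start_ts
        print('{} took {:.2} seconds'.format(func.__name__, dur))

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        if not inspect.iscoroutinefunction(func):
            with wrapping_logic():
                return func(*args, **kwargs)
        else:
            with wrapping_logic():
                return await func(*args, **kwargs)
    return wrapper

@duration
def sync_hello():
    return 'sync result'

print(inspect.iscoroutinefunction(sync_hello))  # True: the wrapper is async
coro = sync_hello()       # only creates a coroutine object; nothing runs yet
print(asyncio.run(coro))  # prints the timing line, then: sync result
```

This is the trade-off of this design: it looks async to frameworks that inspect the function, at the cost of changing the calling convention for sync functions.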

In my testing, this code also works when the decorated function contains try/except blocks or with statements.

It's still not clear to me why the wrapper needs to be async for async FastAPI methods.
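A likely explanation, assuming FastAPI chooses between awaiting an endpoint and running it in a threadpool by checking whether it is a coroutine function: the wrapper in the accepted answer is a plain def that returns a coroutine, so such a check reports it as synchronous, and the coroutine it returns is never awaited by the framework. The sketch below (plain_def_wrapper and async_def_wrapper are hypothetical, reduced versions of the two patterns, with the timing removed) shows the difference inspect.iscoroutinefunction sees:

```python
import asyncio
import functools
import inspect

def plain_def_wrapper(func):
    # accepted-answer pattern, reduced: a regular def that returns a coroutine
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not inspect.iscoroutinefunction(func):
            return func(*args, **kwargs)
        async def tmp():
            return await func(*args, **kwargs)
        return tmp()
    return wrapper

def async_def_wrapper(func):
    # this answer's pattern, reduced: the wrapper itself uses async def
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        if inspect.iscoroutinefunction(func):
            return await func(*args, **kwargs)
        return func(*args, **kwargs)
    return wrapper

async def endpoint():
    return {'ok': True}

a = plain_def_wrapper(endpoint)
b = async_def_wrapper(endpoint)
print(inspect.iscoroutinefunction(a))  # False
print(inspect.iscoroutinefunction(b))  # True
```

Both wrappers produce the same result when awaited; only the introspection differs, and functools.wraps does not change that.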

Upvotes: 7

Mikhail Gerasimov

Reputation: 39546

Maybe you can find a better way to do it, but, for example, you can move your wrapping logic into a context manager to avoid code duplication:

import asyncio
import functools
import time
from contextlib import contextmanager


def duration(func):
    @contextmanager
    def wrapping_logic():
        start_ts = time.time()
        yield
        dur = time.time() - start_ts
        print('{} took {:.2} seconds'.format(func.__name__, dur))

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not asyncio.iscoroutinefunction(func):
            with wrapping_logic():
                return func(*args, **kwargs)
        else:
            async def tmp():
                with wrapping_logic():
                    return (await func(*args, **kwargs))
            return tmp()
    return wrapper
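With this approach, the wrapping logic for a coroutine function runs inside tmp(), i.e. when the returned coroutine is awaited, not when the decorated function is called, so the measured time covers the actual execution. A small check of that ordering, with the timing replaced by a hypothetical events log (traced and fetch are illustrative names):

```python
import asyncio
import functools
import inspect
from contextlib import contextmanager

events = []  # illustrative log of when the wrapping logic runs

def traced(func):
    # same structure as duration() above, with timing replaced by logging
    @contextmanager
    def wrapping_logic():
        events.append('enter')
        yield
        events.append('exit')

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not inspect.iscoroutinefunction(func):
            with wrapping_logic():
                return func(*args, **kwargs)
        else:
            async def tmp():
                with wrapping_logic():
                    return await func(*args, **kwargs)
            return tmp()
    return wrapper

@traced
async def fetch():
    await asyncio.sleep(0)
    return 42

async def main():
    coro = fetch()        # creates the tmp() coroutine; nothing has run yet
    events.append('created')
    result = await coro   # 'enter' and 'exit' are recorded here
    events.append(result)

asyncio.run(main())
print(events)  # ['created', 'enter', 'exit', 42]
```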

Upvotes: 22
