I am trying to create a class as a decorator that will apply a try-except block to the decorated function and retain some log of the exceptions. I want to apply the decorator to both regular functions and coroutines.
I have written the class-as-decorator and it works as designed for regular functions, but something goes wrong with coroutines. Below is a reduced version of the class-as-decorator and a couple of use cases:
import traceback
import asyncio
import functools

class Try:
    def __init__(self, func):
        functools.update_wrapper(self, func)
        self.func = func

    def __call__(self, *args, **kwargs):
        print(f"applying __call__ to {self.func.__name__}")
        try:
            return self.func(*args, **kwargs)
        except:
            print(f"{self.func.__name__} failed")
            print(traceback.format_exc())

    def __await__(self, *args, **kwargs):
        print(f"applying __await__ to {self.func.__name__}")
        try:
            yield self.func(*args, **kwargs)
        except:
            print(f"{self.func.__name__} failed")
            print(traceback.format_exc())

# Case 1
@Try
def times2(x):
    return x*2/0

# Case 2
@Try
async def times3(x):
    await asyncio.sleep(0.0001)
    return x*3/0

async def test_try():
    return await times3(10)

def main():
    times2(10)
    asyncio.run(test_try())
    print("All done")

if __name__ == "__main__":
    main()
Here's the output of the above code (with minor edits):
applying __call__ to times2
times2 failed
Traceback (most recent call last):
File "<ipython-input-3-37071526b2e6>", line 14, in __call__
return self.func(*args, **kwargs)
File "<ipython-input-3-37071526b2e6>", line 30, in times2
return x*2/0
ZeroDivisionError: division by zero
applying __call__ to times3
Traceback (most recent call last):
File "[...]/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3296, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-3-37071526b2e6>", line 46, in <module>
main()
File "<ipython-input-3-37071526b2e6>", line 43, in main
asyncio.run(test_try())
File "[...]/lib/python3.7/asyncio/runners.py", line 43, in run
return loop.run_until_complete(main)
File "[...]/lib/python3.7/asyncio/base_events.py", line 579, in run_until_complete
return future.result()
File "<ipython-input-3-37071526b2e6>", line 39, in test_try
return await times3(10)
File "<ipython-input-3-37071526b2e6>", line 36, in times3
return x*3/0
ZeroDivisionError: division by zero
Case 1 behaves normally: as expected, __call__ is called, then the decorated function fails and the exception is caught. But I can't explain the behaviour of Case 2. Notice the missing "times3 failed" and "All done" prints at the end. I can't reproduce the color-coded output here, but Case 1's traceback is a regular print while Case 2's traceback is exception red (in PyCharm). The surprising part is that the __call__ method was called instead of __await__.
I have tried another class-as-decorator, one that keeps a tally of the number of times a function was called. That works just fine with __call__ for either regular functions or coroutines.
So what is actually going on? Do I need to somehow force the function to use __await__? How?
I tried the following:
async def test_try2():
    func = await times3
with output
applying __await__ to times3
times3 failed
Traceback (most recent call last):
File "<ipython-input-5-5a85f988097e>", line 22, in __await__
yield self.func(*args, **kwargs)
TypeError: times3() missing 1 required positional argument: 'x'
which does force using __await__, but then what?
The problem with your code is that it places __await__ on the wrong object. Generally, await f(x) expands to something like:
_awaitable = f(x)
_iter = _awaitable.__await__()
yield from _iter # not literally[1]
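To make the expansion above concrete, here is a small sketch that performs it by hand. Ready, add_one and drive are hypothetical names, and drive only handles awaitables that finish without ever suspending; it is not how asyncio runs coroutines, it just calls __await__() on the awaited object and reads the result from StopIteration.value:

```python
class Ready:
    """Hypothetical awaitable that completes without ever suspending."""

    def __init__(self, value):
        self.value = value

    def __await__(self):
        # 'yield from ()' yields nothing but makes this a generator
        # function, i.e. __await__ returns an iterator as required.
        yield from ()
        return self.value


async def add_one(x):
    return x + 1


def drive(awaitable):
    """Mimic `await awaitable` for awaitables that never suspend."""
    _iter = awaitable.__await__()  # looked up on the awaited object, not on f
    try:
        next(_iter)  # under an event loop, a yielded value would go to the loop
    except StopIteration as stop:
        return stop.value  # becomes the value of the await expression
    raise RuntimeError("awaitable suspended; a real event loop is needed")


print(drive(Ready(42)))   # 42
print(drive(add_one(1)))  # 2
```

Note that the coroutine returned by add_one(1) supplies its own __await__; nothing attached to the add_one function object is consulted.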
Note how __await__() is called on the result of the function, not on the function object itself. What happens in your times3 example is the following:
1. __call__ calls the original times3 coroutine function stored in self.func, which trivially constructs a coroutine object. There is no exception at this point because the object hasn't started executing yet, so a coroutine object (what you get by calling an async def coroutine function) is returned.
2. __await__ is invoked on the coroutine object obtained by calling self.func, which is the original times3 async def, and not on your function wrapper. This is because, in terms of the pseudocode above, your wrapper corresponds to f, whereas __await__() is invoked on _awaitable, which in your case is the result of calling f.
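Step 1 can be reproduced in isolation. The sketch below assumes a plain closure, the hypothetical call_only, standing in for Try.__call__: its try block sees no exception, because the ZeroDivisionError is raised only once the coroutine actually runs under asyncio.run(), outside that try block:

```python
import asyncio
import inspect


async def times3(x):
    await asyncio.sleep(0)
    return x * 3 / 0


def call_only(func):
    """Hypothetical stand-in for Try.__call__: guards only the call itself."""
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            print(f"{func.__name__} failed")  # never printed for times3
    return wrapper


coro = call_only(times3)(10)
# Calling the coroutine function only builds a coroutine object;
# none of its body has run yet.
print(inspect.iscoroutine(coro))  # True

try:
    asyncio.run(coro)  # the body, and the division, runs only now
    outcome = "no error"
except ZeroDivisionError:
    outcome = "raised at await time"
print(outcome)  # raised at await time
```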
In general you can't know whether the result of a function call will ever be awaited. But since coroutine objects are not useful for anything other than awaiting them (and they even print a warning when destroyed without being awaited), you can safely assume so. This assumption allows your __call__ to check whether the result of the function call is awaitable and, if so, wrap it in an object that will implement your wrapping logic on the __await__ level:
...
import collections.abc

class Try:
    def __init__(self, func):
        functools.update_wrapper(self, func)
        self.func = func

    def __call__(self, *args, **kwargs):
        print(f"applying __call__ to {self.func.__name__}")
        try:
            result = self.func(*args, **kwargs)
        except:
            print(f"{self.func.__name__} failed")
            print(traceback.format_exc())
            return
        if isinstance(result, collections.abc.Awaitable):
            # The result is awaitable, wrap it in an object
            # whose __await__ will call result.__await__()
            # and catch the exceptions.
            return TryAwaitable(result)
        return result

class TryAwaitable:
    def __init__(self, awaitable):
        self.awaitable = awaitable
        # Coroutine objects have __name__, but arbitrary awaitables
        # (e.g. futures) may not, so fall back to repr().
        self.name = getattr(awaitable, "__name__", repr(awaitable))

    def __await__(self, *args, **kwargs):
        print(f"applying __await__ to {self.name}")
        try:
            # yield from must be parenthesized inside a return statement;
            # its value is the result of the wrapped awaitable.
            return (yield from self.awaitable.__await__())
        except:
            print(f"{self.name} failed")
            print(traceback.format_exc())
This results in the expected output:
applying __call__ to times3
applying __await__ to times3
times3 failed
Traceback (most recent call last):
File "wrap3.py", line 30, in __await__
yield from self.awaitable.__await__()
File "wrap3.py", line 44, in times3
return x*3/0
ZeroDivisionError: division by zero
Note that your implementation of __await__ had an unrelated problem: it delegated to the function using yield. One must use yield from instead, because that allows the underlying iterable to choose when to suspend, and also to provide a value once it stops suspending. A bare yield suspends unconditionally (and only once), which is incompatible with the semantics of await.
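The difference can be sketched with two hypothetical wrapper classes, BareYield and YieldFrom (names made up for illustration). Under asyncio, the bare-yield version hands the inner coroutine object itself to the event loop; as far as I know, the Task rejects such an object with a RuntimeError (its message mentions a "bad yield"). The yield from version returns the inner coroutine's value:

```python
import asyncio


class BareYield:
    """Hypothetical wrapper mirroring the question's __await__ (bare yield)."""

    def __init__(self, awaitable):
        self.awaitable = awaitable

    def __await__(self):
        # Wrong: hands the inner coroutine object itself to the event loop.
        yield self.awaitable


class YieldFrom:
    """Hypothetical wrapper that delegates with yield from."""

    def __init__(self, awaitable):
        self.awaitable = awaitable

    def __await__(self):
        # Right: the inner iterator decides when to suspend, and its
        # return value becomes the value of the await expression.
        return (yield from self.awaitable.__await__())


async def compute():
    await asyncio.sleep(0)  # suspends once
    return 30


async def main():
    good = await YieldFrom(compute())

    inner = compute()
    try:
        await BareYield(inner)
        bad = "no error"
    except RuntimeError as exc:
        bad = type(exc).__name__  # the Task rejected the yielded object
    finally:
        inner.close()  # never started; closing avoids a "never awaited" warning
    return good, bad


good, bad = asyncio.run(main())
print(good, bad)  # 30 RuntimeError
```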
[1] Not literally, because yield from is not allowed in async def. But async def behaves as if such a generator were returned by the __await__ method of the object it returns.