Reputation: 29159
I created the following decorator for both async/coroutine and sync functions.
def authorize(role):
def decorator(f):
@contextmanager
def auth(): # Business logic shared by async and sync funtions
if is_authorized(role): # role admin returns True, otherwise False
yield
else:
print('ERROR')
@wraps(f)
def wrapper(*args, **kwargs):
if not asyncio.iscoroutinefunction(f):
with auth():
return f(*args, **kwargs)
else:
async def tmp():
with auth():
return (await f(*args, **kwargs))
return tmp()
return wrapper
return decorator
It works well if is_authorized()
returns True.
@authorize(role='Readonly')
def test():
print('TEST')
test()
However, it will raise exception when is_authorized()
returns False. The decorated function shouldn't be called if it not authorized, it should return 501 HTTP error.
@authorize(role='Readonly')
def test():
print('TEST')
ERROR Traceback (most recent call last): File "", line 1, in File "", line 13, in wrapper File "C:\anaconda3\lib\contextlib.py", line 115, in __enter__ raise RuntimeError("generator didn't yield") from None RuntimeError: generator didn't yield
Upvotes: 1
Views: 1354
Reputation: 1434
The error raised because contextmanager
must be the generator, it means it must always execute yield
statement because yield
separates the __enter__
and __exit__
parts of contextmanager. In your implementation it yields only if is_autorized
returns True
.
And actually you don't need the contexmanager
here, you need simple if
statement.
I'm passing is_authorized
through arguments because it's useful for injecting alternative implementations for testing or other purposes.
import asyncio
import functools
def authorize(role, is_authorized):
def decorator(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
if is_authorized(role):
if asyncio.iscoroutinefunction(f):
async def tmp():
return (await f(*args, **kwargs))
return tmp()
else:
return f(*args, **kwargs)
elif asyncio.iscoroutinefunction(f):
# must return coro anyway
async def tmp():
print("async unauthorized")
return None
return tmp()
else:
print("sync unauthorized")
return None
return wrapper
return decorator
def is_authorized(role):
return role == "lucky"
@authorize("lucky", is_authorized)
async def func1():
await asyncio.sleep(0)
return "coro finished"
@authorize("whatever", is_authorized)
async def func2():
await asyncio.sleep(0)
return "coro would not called"
@authorize("lucky", is_authorized)
def func3():
return "sync func finished"
@authorize("whatever", is_authorized)
def func4():
return "would not called"
if __name__ == "__main__":
print(asyncio.run(func1()))
print(asyncio.run(func2()))
print(func3())
print(func4())
prints
coro finished
async unauthorized
None
sync func finished
sync unauthorized
None
Upvotes: 2