Vadim

Reputation: 89

Decorator, dependency injection and pylance typing

I'm trying to move the creation of the SQLAlchemy session into a decorator, because I have dozens of functions that run many queries and rely on an AsyncSession object. The code works fine.

But can I somehow avoid the "if not session" check?

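from functools import wraps

from sqlalchemy.ext.asyncio import AsyncSession

# Database and Model are defined elsewhere in the project.


def with_session(func):
    @wraps(func)
    async def inner(*args, **kwargs):
        async with Database.session() as session:
            return await func(*args, **kwargs, session=session)

    return inner


@with_session
async def create_model(session: AsyncSession | None = None):
    if not session:  # only needed to satisfy the type checker
        raise ValueError("session is required")
    session.add(Model())
    await session.commit()


async def main():
    await create_model()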

UPDATE 10.02.2023

After making some progress with chepner's solution, I realized that the call "await create_model()" sometimes needs to take a "session" argument (and often other arguments too). So I wrote a new version using protocols.

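from typing import Awaitable, ParamSpec, Protocol, TypeVar

from sqlalchemy.ext.asyncio import AsyncSession

T = TypeVar('T', covariant=True)
P = ParamSpec('P')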


class SomeType(Protocol[T, P]):
    async def __call__(self, session: AsyncSession, *args: P.args, **kwargs: P.kwargs) -> T:
        ...


class SomeTypePossible(Protocol[T, P]):
    async def __call__(self, session: AsyncSession | None = None, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T]:
        ...


def with_session(func: SomeType[T, P]) -> SomeTypePossible[T, P]:
    async def inner(*args: P.args, **kwargs: P.kwargs):
        if 'session' in kwargs and (s := kwargs.pop('session')) and isinstance(s, AsyncSession):
            return await func(s, *args, **kwargs)
        else:
            async with Database.session() as session:
                return await func(session=session, *args, **kwargs)

    return inner


@with_session
async def create_model(session: AsyncSession, a: str | None = None) -> str:
    await session.commit()
    return 'test'


async def main():
    async with Database.session() as s:
        p = await create_model(session=s, a='other_test')
        print(p)

Almost everything works, except for a Pylance warning on "return inner":

Expression of type "(*args: P@with_session.args, **kwargs: P@with_session.kwargs) -> Coroutine[Any, Any, T@with_session]" cannot be assigned to return type "SomeTypePossible[T@with_session, P@with_session]"
  Type "(*args: P@with_session.args, **kwargs: P@with_session.kwargs) -> Coroutine[Any, Any, T@with_session]" cannot be assigned to type "(session: AsyncSession | None = None, **P@with_session) -> Coroutine[Any, Any, Awaitable[T@with_session]]"
    Function return type "Coroutine[Any, Any, T@with_session]" is incompatible with type "Coroutine[Any, Any, Awaitable[T@with_session]]"
      "Coroutine[Any, Any, T@with_session]" is incompatible with "Coroutine[Any, Any, Awaitable[T@with_session]]"
        Type parameter "_ReturnT_co@Coroutine" is covariant, but "T@with_session" is not a subtype of "Awaitable[T@with_session]"
          "object*" is incompatible with protocol "Awaitable[T@with_session]"

I'm not sure what is happening here. Could anyone explain? Thanks!
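The mismatch can be reproduced without the decorator. This is a minimal sketch; good, bad, call_good and demo are illustrative names, not from the post. On an "async def", the return annotation describes the value obtained after awaiting, and calling the function yields Coroutine[Any, Any, <annotation>]. So "async def __call__(...) -> Awaitable[T]" in SomeTypePossible promises a coroutine that resolves to yet another awaitable, while inner resolves directly to T, which matches the "Coroutine[Any, Any, T]" vs "Coroutine[Any, Any, Awaitable[T]]" line of the message.

```python
import asyncio
from typing import Any, Awaitable, Coroutine


async def good() -> int:
    # The annotation is the type of the awaited result.
    return 42


async def bad() -> Awaitable[int]:
    # Annotated like SomeTypePossible.__call__: awaiting bad() yields
    # another awaitable, so a checker types bad() as
    # Coroutine[Any, Any, Awaitable[int]].
    return good()


def call_good() -> Coroutine[Any, Any, int]:
    # Calling an "async def ... -> int" produces Coroutine[Any, Any, int].
    return good()


async def demo() -> tuple[int, int]:
    first = await call_good()       # one await reaches the int
    second = await (await bad())    # two awaits are needed
    return first, second


print(asyncio.run(demo()))  # (42, 42)
```

Annotating __call__ with "-> T" instead removes the extra layer of wrapping.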

Upvotes: 1

Views: 293

Answers (2)

Vadim

Reputation: 89

My final solution uses protocols. It works with VSCode 1.86.0, Python 3.10.13 and Pylance v2024.2.1 in strict type-checking mode.

from typing import ParamSpec, Protocol, TypeVar

from sqlalchemy.ext.asyncio import AsyncSession

T = TypeVar('T', covariant=True)
P = ParamSpec('P')


class SomeType(Protocol[T, P]):
    async def __call__(self, session: AsyncSession, *args: P.args, **kwargs: P.kwargs) -> T:
        ...


class SomeTypePossible(Protocol[T, P]):
    async def __call__(self, session: AsyncSession | None = None, *args: P.args, **kwargs: P.kwargs) -> T:
        ...


def with_session(func: SomeType[T, P]) -> SomeTypePossible[T, P]:
    async def inner(session: AsyncSession | None = None, *args: P.args, **kwargs: P.kwargs) -> T:
        if session is not None:
            # Reuse the session supplied by the caller.
            return await func(session, *args, **kwargs)
        # Otherwise open a new one (Database is defined elsewhere in the project).
        async with Database.session() as new_session:
            return await func(new_session, *args, **kwargs)

    return inner


@with_session
async def create_task(session: AsyncSession, *, task: schemas.Task) -> models.Task:
    model = models.Task(**task.model_dump())
    ...
    return model



async def main():
    task = something()  # placeholder for however the task is built
    # If needed, open a session here and call:
    # await create_task(session=my_new_session, task=task)
    result = await create_task(task=task)

Upvotes: 0

chepner

Reputation: 532113

You can give with_session type hints that tell the type checker that the decorator changes the type of create_model et al.

from typing import Awaitable, Callable, Concatenate, ParamSpec, TypeVar

from sqlalchemy.ext.asyncio import AsyncSession


P = ParamSpec('P')
RV = TypeVar('RV')


def with_session(func: Callable[Concatenate[AsyncSession, P], Awaitable[RV]]) -> Callable[P, Awaitable[None]]:
    async def inner(*args: P.args, **kwargs: P.kwargs) -> None:
        async with Database.session() as session:
            await func(session, *args, **kwargs)

    return inner


@with_session
async def create_model(session: AsyncSession):
    session.add(Model())
    await session.commit()

  • Callable[Concatenate[AsyncSession, P], Awaitable[RV]]

    The type of func: an arbitrary callable whose first argument is an AsyncSession; any further parameters are captured by the parameter specification P. It returns an awaitable that produces some arbitrary type RV.

  • Callable[P, Awaitable[None]]

    The return type of with_session: another callable with the same parameters as func minus the leading AsyncSession, and a hard-coded result type of None.

Now we can define create_model as a callable that must receive an AsyncSession, so there is no need to check if session is None. The original create_model is bound to the name func in the closure returned by with_session; the new create_model is the closure, which only takes parameters captured by P (which does not include AsyncSession).

(I don't use VSCode, so I can't test this with Pylance, but it works with mypy.)

Upvotes: 1
