gkrupp

Reputation: 428

Wrapping class method with correct type hint

I am implementing run in subclasses of ParentClass; ParentClass will call run from its __call__ method.

I want the type hints from Subclass.run to apply to call sites of its __call__ method.

I would like to define the child classes (e.g. MyClass) like this:

from abc import abstractmethod

class ParentClass:
    # This class can contain the nasty definitions necessary
    #   to keep child classes nice looking
    @abstractmethod
    def run(self):
        raise NotImplementedError()

class MyClass(ParentClass):
    # I don't want to use `metaclass=MetaClass` here

    # Define the logic in the method `run(...)`,
    #   signature can vary in different child classes
    # Trying to avoid the need to use the `__call__` dunder
    def run(self, a: int = 5) -> int:
        print("my", a)
        return 1

class MyClass2(ParentClass):
    def run(self, b: str = "", c: bool = False) -> int:
        print("my2", b, c)
        return 2

How I want to use the child classes:

my = MyClass()
my(a=6) # prints `my 6` and returns `1`

my2 = MyClass2()
my2(b="example", c=True) # prints `my2 example True` and returns `2`

My reason for the above is that I want to wrap the run methods and do something before and after each call.

I've tried the following:

from __future__ import annotations  # must be the first statement in the module
from abc import abstractmethod
from functools import wraps
from typing import Callable, TypeVar, ParamSpec, Self, Any, cast

P = ParamSpec("P")
T = TypeVar("T")

class ParentMeta(type):

    def __new__(cls: Self, name: str, bases: tuple, namespace: dict[str, Any]) -> ParentMeta:

        def wrapper(func: Callable[P, T]) -> Callable[P, T]:
            @wraps(func)
            def inner(self, *args: P.args, **kwargs: P.kwargs) -> T:
                print("--pre--")
                ret = func(self, *args, **kwargs)
                print("--post--")
                return ret
            return inner
        
        if "run" in namespace:
            namespace["__call__"] = wrapper(namespace["run"])
            del namespace["run"]
        
        return cast(ParentMeta, super().__new__(cls, name, bases, namespace))
    
class ParentClass(object, metaclass=ParentMeta):
    @abstractmethod
    def run(self):
        raise NotImplementedError()

class MyClass(ParentClass):
    def run(self, a: int = 5) -> int:
        print("my", a)
        return 1

my = MyClass()
my(a=6)
# the problem here is that I lost the parameter/type hint in the IDE (I'm using VSCode)
# (so when I'm writing `my(`, I cannot see that it has an argument `a`)

If I ditch run and define __call__ instead (which I really don't want to do), I do get a type hint, since it is resolved from MyClass:

from __future__ import annotations  # must be the first statement in the module
from abc import abstractmethod
from functools import wraps
from typing import Callable, TypeVar, ParamSpec, Self, Any, cast

P = ParamSpec("P")
T = TypeVar("T")

class ParentMeta(type):

    def __new__(cls: Self, name: str, bases: tuple, namespace: dict[str, Any]) -> ParentMeta:

        def wrapper(func: Callable[P, T]) -> Callable[P, T]:
            @wraps(func)
            def inner(self, *args: P.args, **kwargs: P.kwargs) -> T:
                print("--pre--")
                ret = func(self, *args, **kwargs)
                print("--post--")
                return ret
            return inner
        
        if "__call__" in namespace:
            namespace["__call__"] = wrapper(namespace["__call__"])
        
        return cast(ParentMeta, super().__new__(cls, name, bases, namespace))
    
class ParentClass(object, metaclass=ParentMeta):
    @abstractmethod
    def __call__(self):
        raise NotImplementedError()

class MyClass(ParentClass):
    def __call__(self, a: int = 5) -> int:
        print("my", a)
        return 1

my = MyClass()
my(a=6) # got the hint `(a: int = 5) -> int` when typed

Can I write the above in such a way that I don't have to use __call__ in MyClass and still get a proper type hint at the call site?

I'm open to simpler or more complex solutions and to other Python versions (I used 3.11.10).

Upvotes: 1

Views: 164

Answers (1)

aleksv

Reputation: 71

Suggestion #1

To give type checkers a definition of __call__ to find, you can simply add a stub to the parent class:

class ParentClass(metaclass=ParentMeta):
    @abstractmethod
    def run(self, *args, **kwargs):
        raise NotImplementedError()

    def __call__(self, *args, **kwargs):
        ...

For consistency, you can also make the metaclass install a __call__ that raises an exception if the required method is not implemented:

# Inside ParentMeta.__new__, in place of the `if "run" in namespace:` block
if "run" in namespace:
    namespace["__call__"] = wrapper(namespace["run"])
else:
    def default_call(self, *args, **kwargs):
        raise NotImplementedError(f"{name} must implement the 'run' method.")
    namespace["__call__"] = default_call

cls_instance = cast(ParentMeta, super().__new__(cls, name, bases, namespace))

return cls_instance
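Assembled into one piece, Suggestion #1 could look roughly like the sketch below. Two assumptions differ from the snippets above: run is left in the class namespace, and the else branch installs default_call only when no base already provides a non-abstract run (the helper _inherits_concrete_run is hypothetical). Without such a check, a subclass of MyClass that does not redefine run would have its inherited __call__ replaced by the error stub. SubMyClass and Incomplete are hypothetical classes used to show both cases. Statically, a type checker sees ParentClass.__call__(*args, **kwargs), so calls are accepted, but the hint offered is that generic signature rather than (a: int = 5).

```python
from __future__ import annotations

from abc import abstractmethod
from functools import wraps
from typing import Any, Callable


def _inherits_concrete_run(bases: tuple[type, ...]) -> bool:
    # True if some base already provides a `run` that is not abstract
    for base in bases:
        run = getattr(base, "run", None)
        if run is not None and not getattr(run, "__isabstractmethod__", False):
            return True
    return False


class ParentMeta(type):
    def __new__(mcls, name: str, bases: tuple[type, ...], namespace: dict[str, Any]) -> ParentMeta:
        def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
            @wraps(func)
            def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
                print("--pre--")
                ret = func(self, *args, **kwargs)
                print("--post--")
                return ret
            return inner

        if "run" in namespace:
            namespace["__call__"] = wrapper(namespace["run"])
        elif not _inherits_concrete_run(bases):
            def default_call(self: Any, *args: Any, **kwargs: Any) -> Any:
                raise NotImplementedError(f"{name} must implement the 'run' method.")
            namespace["__call__"] = default_call
        # Otherwise keep the __call__ inherited from the base that implements run
        return super().__new__(mcls, name, bases, namespace)


class ParentClass(metaclass=ParentMeta):
    @abstractmethod
    def run(self, *args, **kwargs):
        raise NotImplementedError()

    # Stub so type checkers know instances are callable;
    # the metaclass replaces it with the wrapped `run` at runtime
    def __call__(self, *args, **kwargs):
        ...


class MyClass(ParentClass):
    def run(self, a: int = 5) -> int:
        print("my", a)
        return 1


class SubMyClass(MyClass):  # hypothetical: inherits run from MyClass
    pass


class Incomplete(ParentClass):  # hypothetical: never implements run
    pass
```

Here SubMyClass()(a=7) still goes through MyClass's wrapped run, while Incomplete()() raises NotImplementedError naming Incomplete.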

Suggestion #2

More generally, you can use an architectural approach based only on abstract classes, without a metaclass:

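from abc import ABC, abstractmethod
from functools import wraps
from typing import Any, ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T")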


class CallableBase(ABC):
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if 'run' in cls.__dict__:
            original_run = cls.run  # noqa

            @wraps(original_run)
            def wrapped_run(self, *args: P.args, **kwargs: P.kwargs) -> T:
                print("--pre--")
                result = original_run(self, *args, **kwargs)
                print("--post--")
                return result

            cls.run = wrapped_run

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Any:
        if hasattr(self, 'run'):
            return self.run(*args, **kwargs)
        else:
            raise NotImplementedError(f"{self.__class__.__name__} does not implement a 'run' method.")


class ParentClass(CallableBase):
    @abstractmethod
    def run(self, *args: P.args, **kwargs: P.kwargs) -> T:
        raise NotImplementedError()


class MyClass(ParentClass):
    def run(self, a: int = 5) -> int:
        print("my", a)
        return 1

Upvotes: 0
