LibrarristShalinward

Reputation: 91

In PyTorch, how to make __call__() of nn.Module automatically copy the type hints and docstring of forward()?

It is common to implement the forward pass in forward() and invoke it through __call__, as in this example:

from torch import nn, FloatTensor, IntTensor

class MyModule(nn.Module): 
    def __init__(self, ...) -> None: 
        nn.Module.__init__(self)
        ...
    def forward(self, x: FloatTensor, y: FloatTensor) -> tuple[FloatTensor, IntTensor]: 
        """
        Args:
            x (FloatTensor): in shape of BxTxE
            y (FloatTensor): in shape of BxE

        Returns:
            tuple[FloatTensor, IntTensor]: (sth. in shape of BxT, sth. in shape of B)
        """
        ... # implementation of forward steps

model = MyModule(...)
...
a, b = model(x, y) # call it through __call__

However, IDEs like VSCode cannot recognize the type hints or docstring of forward() at the call site, because __call__ is a completely different method rather than an overload. Though this is reasonable from Python's point of view, it is still unfriendly in situations like collaborative development, where convenient code hints matter.
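
To make the symptom concrete: on recent torch versions, nn.Module.__call__ is annotated as Callable[..., Any], so a static checker infers nothing useful from the call. A minimal illustration (reveal_type is the special form understood by checkers like Pyright and mypy):

model = MyModule(...)
result = model(x, y)  # dispatched through nn.Module.__call__, typed (*args, **kwargs) -> Any
reveal_type(result)   # a checker reports "Any", not "tuple[FloatTensor, IntTensor]"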

A possible but clumsy solution is to copy that information into an override of __call__() in every nn.Module subclass:

from torch import nn, FloatTensor, IntTensor

class MyModule(nn.Module): 
    def __init__(self, ...) -> None: 
        nn.Module.__init__(self)
        ...
    def forward(self, x: FloatTensor, y: FloatTensor) -> tuple[FloatTensor, IntTensor]: 
        """
        Args:
            x (FloatTensor): in shape of BxTxE
            y (FloatTensor): in shape of BxE

        Returns:
            tuple[FloatTensor, IntTensor]: (sth. in shape of BxT, sth. in shape of B)
        """
        ... # implementation of forward steps
    def __call__(self, x: FloatTensor, y: FloatTensor) -> tuple[FloatTensor, IntTensor]: 
        """
        Args:
            x (FloatTensor): in shape of BxTxE
            y (FloatTensor): in shape of BxE

        Returns:
            tuple[FloatTensor, IntTensor]: (sth. in shape of BxT, sth. in shape of B)
        """
        return nn.Module.__call__(self, x, y)

model = MyModule(...)
...
a, b = model(x, y) # call through __call__

So, how can I tell Python or VSCode that __call__() and forward() share identical input/return types and docstrings in any subclass of nn.Module, without writing them out again in an override of __call__() in each subclass?

(I guess a possible solution for the docstrings may be decorators, but I have no idea how to copy the type hints.)

Upvotes: 1

Views: 58

Answers (1)

LibrarristShalinward

Reputation: 91

Just a temporary solution, only a little better than copying the hints each time, inspired by @InSync and the decorator in a related question:

from typing import TypeVar, Callable
from typing_extensions import ParamSpec # for 3.7 <= Python < 3.10; on 3.10+, import it from typing

T = TypeVar('T')
P = ParamSpec('P')
def take_annotation_from(this: Callable[P, T]) -> Callable[[Callable], Callable[P, T]]: 
    def decorator(real_function: Callable[P, T]) -> Callable[P, T]: 
        # annotating real_function as Callable[P, T] already tells the type checker
        # the parameter and return types (at least for code hints in VSCode),
        # so wrapping real_function as in the related question is unnecessary
        real_function.__doc__ = this.__doc__
        return real_function
    return decorator

And use it as follows:

from torch.nn import Module

class MyModule(Module): 
    def __init__(self, k: float):
        ...
    
    def forward(self, ...) -> ...: # with type hints
        """docstring"""
        ...

    @take_annotation_from(forward)
    def __call__(self, *args, **kwds):
        return Module.__call__(self, *args, **kwds)
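
For what it's worth, a quick runtime check (assuming the MyModule sketch above is filled in) confirms the docstring is now attached to __call__, while VSCode picks up the parameter hints from the ParamSpec:

print(MyModule.__call__.__doc__)  # -> "docstring", copied from forward()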

This solution could be further improved if the last three lines of the code above could be packed into something like a macro, since they remain unchanged across different nn.Module subclasses.
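
One way to approximate such a "macro" at runtime is an __init_subclass__ hook on a shared base class. A minimal sketch (the name AutoHintedModule is hypothetical, not part of torch):

from torch.nn import Module

class AutoHintedModule(Module):
    """Hypothetical base class: every subclass automatically gets a
    __call__ carrying the docstring of its own forward()."""
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)

        def __call__(self, *args, **kwds):
            return Module.__call__(self, *args, **kwds)

        # copy this subclass's forward() docstring onto its __call__
        __call__.__doc__ = cls.forward.__doc__
        cls.__call__ = __call__

Since this __call__ is assigned at class-creation time, static checkers will not see its signature; it only repairs the runtime docstring (e.g. for help()), so the ParamSpec decorator above is still needed for the type hints.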

Upvotes: 0
