It is common to implement the forward steps in forward() and call them through __call__, as in this example:
```python
from torch import nn, FloatTensor, IntTensor

class MyModule(nn.Module):
    def __init__(self) -> None:  # constructor arguments elided
        nn.Module.__init__(self)
        ...

    def forward(self, x: FloatTensor, y: FloatTensor) -> tuple[FloatTensor, IntTensor]:
        """
        Args:
            x (FloatTensor): of shape BxTxE
            y (FloatTensor): of shape BxE
        Returns:
            tuple[FloatTensor, IntTensor]: (something of shape BxT, something of shape B)
        """
        ...  # implementation of forward steps

model = MyModule()
...  # x and y are prepared here
a, b = model(x, y)  # called through __call__
```
However, IDEs like VSCode cannot show the type hints or docstring of forward() when the model is called through __call__, because __call__ is a completely different method that the subclass does not override.
Although this is reasonable from Python's point of view, it is unfriendly in situations such as team collaboration, where convenient code hints matter.
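A torch-free sketch of the situation: `Base` and `MyModule` below are hypothetical stand-ins (the real nn.Module.__call__ also runs hooks and more). Calling the instance works, but the signature and docstring attached to `__call__` are the generic ones inherited from the base class, not forward()'s, which is roughly what an IDE falls back to at the call site.

```python
import inspect
from typing import Any


class Base:
    """Hypothetical stand-in for nn.Module: __call__ only dispatches to forward()."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Generic entry point; this is the signature and docstring a tool finds."""
        return self.forward(*args, **kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError


class MyModule(Base):
    def forward(self, x: float, y: float) -> tuple[float, int]:
        """Args: x, y. Returns: (x * y, 1)."""
        return x * y, 1


model = MyModule()
result = model(2.0, 3.0)                        # runs forward() via Base.__call__
call_sig = inspect.signature(model.__call__)    # generic *args/**kwargs signature
forward_sig = inspect.signature(model.forward)  # annotated x/y signature
```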
A possible but clumsy solution is to copy this information into an override of __call__() in each nn.Module subclass:
```python
from torch import nn, FloatTensor, IntTensor

class MyModule(nn.Module):
    def __init__(self) -> None:  # constructor arguments elided
        nn.Module.__init__(self)
        ...

    def forward(self, x: FloatTensor, y: FloatTensor) -> tuple[FloatTensor, IntTensor]:
        """
        Args:
            x (FloatTensor): of shape BxTxE
            y (FloatTensor): of shape BxE
        Returns:
            tuple[FloatTensor, IntTensor]: (something of shape BxT, something of shape B)
        """
        ...  # implementation of forward steps

    def __call__(self, x: FloatTensor, y: FloatTensor) -> tuple[FloatTensor, IntTensor]:
        """
        Args:
            x (FloatTensor): of shape BxTxE
            y (FloatTensor): of shape BxE
        Returns:
            tuple[FloatTensor, IntTensor]: (something of shape BxT, something of shape B)
        """
        return nn.Module.__call__(self, x, y)

model = MyModule()
...  # x and y are prepared here
a, b = model(x, y)  # called through __call__
```
So, how can I tell Python or VSCode that __call__() and forward() share identical input/return types and docstring in any subclass of nn.Module, without writing them again in an override of __call__() in each subclass?

(I guess a possible solution for docstrings is a decorator, but I have no idea how to copy type hints.)
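For the docstring half, a decorator can indeed copy metadata at runtime. In this sketch, `same_metadata_as` is a hypothetical helper built on `functools.update_wrapper`, and `Base` again stands in for nn.Module. Copying `__doc__` and `__annotations__` and setting `__wrapped__` makes `help()` and `inspect.signature()` report forward()'s parameters, but static analysers such as the one behind VSCode do not execute this code, so on its own it mainly helps runtime introspection rather than editor hints.

```python
import functools
import inspect
from typing import Any, Callable


def same_metadata_as(source: Callable[..., Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical helper: copy source's docstring/annotations onto the decorated function."""

    def decorator(target: Callable[..., Any]) -> Callable[..., Any]:
        # Copy only __doc__ and __annotations__ (target keeps its own __name__);
        # update_wrapper also sets target.__wrapped__ = source.
        return functools.update_wrapper(
            target, source, assigned=("__doc__", "__annotations__"), updated=()
        )

    return decorator


class Base:
    """Hypothetical stand-in for nn.Module."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError


class MyModule(Base):
    def forward(self, x: float, y: float) -> tuple[float, int]:
        """Args: x, y. Returns: (x * y, 1)."""
        return x * y, 1

    @same_metadata_as(forward)
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return Base.__call__(self, *args, **kwargs)


model = MyModule()
sig = inspect.signature(model.__call__)  # follows __wrapped__ to forward()
```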
Here is a temporary solution, only slightly better than copying the hints each time, inspired by @InSync and the decorator in a related question:
```python
from typing import Callable, TypeVar
from typing_extensions import ParamSpec  # Python 3.7 to 3.9; on 3.10+ import ParamSpec from typing

T = TypeVar('T')
P = ParamSpec('P')

def take_annotation_from(this: Callable[P, T]) -> Callable[[Callable], Callable[P, T]]:
    def decorator(real_function: Callable[P, T]) -> Callable[P, T]:
        # Annotating with Callable[P, T] tells static tools the parameter and return
        # types of `this` (at least for code hints in VSCode), so wrapping
        # real_function as in the related question is unnecessary.
        real_function.__doc__ = this.__doc__
        return real_function
    return decorator
```
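A torch-free way to check what this decorator does at runtime, as a sketch: `Base` is a hypothetical stand-in for nn.Module whose `__call__` counts calls where the real class would run hooks, and the `ParamSpec` import uses a try/except fallback instead of a version comment. The decorated `__call__` still routes through the base class's `__call__`; only `__doc__` and the static type change.

```python
from typing import Any, Callable, TypeVar

try:  # Python 3.10+
    from typing import ParamSpec
except ImportError:  # Python 3.7 to 3.9: typing_extensions backport
    from typing_extensions import ParamSpec

T = TypeVar("T")
P = ParamSpec("P")


def take_annotation_from(this: Callable[P, T]) -> Callable[[Callable[..., Any]], Callable[P, T]]:
    def decorator(real_function: Callable[..., Any]) -> Callable[P, T]:
        real_function.__doc__ = this.__doc__  # runtime: copy the docstring
        return real_function  # same object; only its static type changes
    return decorator


class Base:
    """Hypothetical stand-in for nn.Module; counts calls where hooks would run."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        self.calls += 1
        return self.forward(*args, **kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError


class MyModule(Base):
    def forward(self, x: float, y: float) -> tuple[float, int]:
        """Args: x, y. Returns: (x * y, 1)."""
        return x * y, 1

    @take_annotation_from(forward)
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return Base.__call__(self, *args, **kwargs)


model = MyModule()
a, b = model(2.0, 3.0)
```

This is also why writing `__call__ = forward` would be a worse shortcut: it would skip the base class's `__call__`, which in nn.Module is where hooks run.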
Use it like this:
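```python
from torch import FloatTensor, IntTensor
from torch.nn import Module

class MyModule(Module):
    def __init__(self, k: float) -> None:
        super().__init__()  # required, otherwise Module.__call__ fails
        ...

    def forward(self, x: FloatTensor, y: FloatTensor) -> tuple[FloatTensor, IntTensor]:  # with type hints
        """docstring"""
        ...

    @take_annotation_from(forward)
    def __call__(self, *args, **kwds):
        return Module.__call__(self, *args, **kwds)
```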
This solution could be improved if the last three lines of the code above could be packaged into something like a macro, since they stay the same across different nn.Module subclasses.
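One way to get close to that "macro" is a small factory function, so each subclass needs a single line. This is a sketch under assumptions: `call_like` is a hypothetical helper, `Base` stands in for nn.Module (with torch you would pass `Module.__call__` instead of `Base.__call__`), and I have not verified that every type checker binds `self` for a callable assigned in the class body the way it does for a decorated `def`. If yours does not, the decorator form is the safer choice.

```python
from typing import Any, Callable, TypeVar, cast

try:  # Python 3.10+
    from typing import ParamSpec
except ImportError:  # Python 3.7 to 3.9: typing_extensions backport
    from typing_extensions import ParamSpec

T = TypeVar("T")
P = ParamSpec("P")


def call_like(forward: Callable[P, T], base_call: Callable[..., Any]) -> Callable[P, T]:
    """Hypothetical helper: a fresh __call__ that delegates to `base_call`
    but is typed and documented like `forward`."""

    def __call__(*args: Any, **kwargs: Any) -> Any:
        # args[0] is the instance, because the result is bound as a method.
        return base_call(*args, **kwargs)

    # A new function object per class, so the base method's __doc__ is untouched.
    __call__.__doc__ = forward.__doc__
    return cast(Callable[P, T], __call__)


class Base:
    """Hypothetical stand-in for nn.Module; counts calls where hooks would run."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        self.calls += 1
        return self.forward(*args, **kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError


class MyModule(Base):
    def forward(self, x: float, y: float) -> tuple[float, int]:
        """Args: x, y. Returns: (x * y, 1)."""
        return x * y, 1

    __call__ = call_like(forward, Base.__call__)  # the one-line "macro"


model = MyModule()
a, b = model(2.0, 3.0)
```

Building a new function per class matters: passing the shared base method straight through a decorator that assigns `__doc__` would overwrite the docstring for every subclass at once.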