Reputation: 379
I am trying to find a nicer way of statically typing the following code without the repetition:
import numpy as np
from typing import Any, Iterable, Mapping, Sequence, Tuple, Union, Protocol

Shape = Sequence[Union[int, np.int32, np.int64]]
ShapeTree = Union[Shape, Iterable['ShapeTree'], Mapping[Any, 'ShapeTree']]

class InitFn_1(Protocol):
    def __call__(self, input_shape: Shape) -> Shape:
        ...

class InitFn_2(Protocol):
    def __call__(self, input_shape: Shape) -> ShapeTree:
        ...

class InitFn_3(Protocol):
    def __call__(self, input_shape: ShapeTree) -> Shape:
        ...

class InitFn_4(Protocol):
    def __call__(self, input_shape: ShapeTree) -> ShapeTree:
        ...

InitFn = Union[InitFn_1, InitFn_2, InitFn_3, InitFn_4]
Is there a shorter way of doing this?
Upvotes: 0
Views: 138
Reputation: 3360
Note that the four cases can be collapsed into a single signature, Callable[[Shape | ShapeTree], Shape | ShapeTree], so one Protocol is enough:
from typing import Any, Iterable, Mapping, Sequence, Union, Protocol
import numpy as np

Shape = Sequence[Union[int, np.int32, np.int64]]
ShapeTree = Union[Shape, Iterable["ShapeTree"], Mapping[Any, "ShapeTree"]]

class InitFn(Protocol):
    def __call__(self, input_shape: Shape | ShapeTree) -> Shape | ShapeTree:
        ...
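
If the four cases need to stay distinct, another option might be a single generic Protocol that is specialized four times. This is only a sketch: `InitFnOf`, `In`, `Out`, `same_shape` and `split_shape` are hypothetical names, and `Shape` is simplified to `Sequence[int]` so the example runs without numpy.

```python
from typing import Any, Iterable, Mapping, Protocol, Sequence, TypeVar, Union

# Assumption: Shape is simplified to plain ints so the sketch runs without numpy.
Shape = Sequence[int]
ShapeTree = Union[Shape, Iterable["ShapeTree"], Mapping[Any, "ShapeTree"]]

# Hypothetical type variables: the parameter is contravariant, the return covariant.
In = TypeVar("In", contravariant=True)
Out = TypeVar("Out", covariant=True)


class InitFnOf(Protocol[In, Out]):
    """One generic protocol in place of InitFn_1 .. InitFn_4."""

    def __call__(self, input_shape: In) -> Out:
        ...


# The four cases from the question, written as specializations.
InitFn = Union[
    InitFnOf[Shape, Shape],
    InitFnOf[Shape, ShapeTree],
    InitFnOf[ShapeTree, Shape],
    InitFnOf[ShapeTree, ShapeTree],
]


def same_shape(input_shape: Shape) -> Shape:
    # Hypothetical initializer: returns the input shape as a tuple.
    return tuple(input_shape)


def split_shape(input_shape: Shape) -> ShapeTree:
    # Hypothetical initializer: returns a two-element tree of shapes.
    return [tuple(input_shape), tuple(input_shape)]


# Both assignments are intended to match one of the four specializations.
fn_a: InitFn = same_shape
fn_b: InitFn = split_shape
```

One reason to keep the cases separate: callable parameters are contravariant, so a function annotated to take only `Shape` would not be expected to satisfy a protocol whose parameter is `Shape | ShapeTree`. It is worth confirming the behavior with your own type checker.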
Upvotes: 2