Miguel Monteiro

Reputation: 379

Python typing with recursive types and protocols

I am trying to find a nicer way of statically typing the following code without the repetition:

import numpy as np
from typing import Any, Iterable, Mapping, Sequence, Union, Protocol

Shape = Sequence[Union[int, np.int32, np.int64]]
ShapeTree = Union[Shape, Iterable['ShapeTree'], Mapping[Any, 'ShapeTree']]

class InitFn_1(Protocol):
    def __call__(self, input_shape: Shape) -> Shape:
        ...

class InitFn_2(Protocol):
    def __call__(self, input_shape: Shape) -> ShapeTree:
        ...

class InitFn_3(Protocol):
    def __call__(self, input_shape: ShapeTree) -> Shape:
        ...

class InitFn_4(Protocol):
    def __call__(self, input_shape: ShapeTree) -> ShapeTree:
        ...

InitFn = Union[InitFn_1, InitFn_2, InitFn_3, InitFn_4]

Is there a shorter way of doing this?

Upvotes: 0

Views: 138

Answers (1)

Paweł Rubin

Reputation: 3360

Note that the union of all four cases is equivalent to Callable[[Shape | ShapeTree], Shape | ShapeTree], which can be written as a single protocol:

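from __future__ import annotations  # keeps the "Shape | ShapeTree" annotations unevaluated, so they also work before Python 3.10

from typing import Any, Iterable, Mapping, Sequence, Union, Protocol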

import numpy as np

Shape = Sequence[Union[int, np.int32, np.int64]]
ShapeTree = Union[Shape, Iterable["ShapeTree"], Mapping[Any, "ShapeTree"]]


class InitFn(Protocol):
    def __call__(self, input_shape: Shape | ShapeTree) -> Shape | ShapeTree:
        ...
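Since Shape is already one of the members of ShapeTree, Shape | ShapeTree denotes the same type as ShapeTree, so the protocol above has the same signature as InitFn_4 in the question. The sketch below checks that collapse at runtime and shows two init functions that fit the protocol. It is a minimal sketch under stated assumptions: Shape is simplified to Sequence[int] so it runs without NumPy, it needs Python 3.8+ for Protocol, and passthrough, fixed_output and flat_only are hypothetical names used only for illustration.

```python
from typing import Any, Iterable, Mapping, Protocol, Sequence, Union

# Simplified stand-ins for the aliases above: the NumPy integer types are
# dropped (an assumption) so this sketch has no third-party dependency.
Shape = Sequence[int]
ShapeTree = Union[Shape, Iterable["ShapeTree"], Mapping[Any, "ShapeTree"]]

# typing flattens nested Unions and drops duplicate members, so adding
# Shape to ShapeTree yields the very same union.
assert Union[Shape, ShapeTree] == ShapeTree


class InitFn(Protocol):
    # Same signature as the "Shape | ShapeTree" version (and as InitFn_4).
    def __call__(self, input_shape: ShapeTree) -> ShapeTree:
        ...


def passthrough(input_shape: ShapeTree) -> ShapeTree:
    """Hypothetical init fn: returns the input tree unchanged."""
    return input_shape


def fixed_output(input_shape: ShapeTree) -> Shape:
    """Hypothetical init fn: ignores its input and returns a flat shape."""
    return (1, 28, 28)


init: InitFn = passthrough  # exact match for the protocol
init = fixed_output  # also accepted: a Shape return value is a valid ShapeTree

# By contrast, a hypothetical function whose parameter is the narrower Shape,
# e.g. def flat_only(input_shape: Shape) -> Shape, is rejected by static
# checkers such as mypy when assigned to InitFn: parameter types are checked
# contravariantly, and it does not promise to accept every ShapeTree.

print(init(((2, 3), {"bias": (3,)})))  # prints (1, 28, 28)
```

If functions that accept only a flat Shape (the InitFn_1 case) must be allowed as well, the widest of the four original signatures is InitFn_2 (Shape in, ShapeTree out). By the same contravariance rule, InitFn_1, InitFn_3 and InitFn_4 are each assignable to it, so a type checker should treat the original four-member union as equivalent to InitFn_2 alone.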

Upvotes: 2
