MacFreek

Reputation: 3446

Python type checking on a function that returns str or Sequence

I have a Python function that processes a sequence, and returns the same sort of sequence. E.g. if it is fed a list of integers, it will return a list of integers, and if it is fed a string, it will return a string.

How do I add type hints for this function?

The following works fine, but is not a very strict type check:

from typing import Any, TypeVar, Sequence

_S = Any

def process_sequence(s: _S) -> _S:
    return s

def print_str(word: str):
    print(word)

def print_sequence_of_ints(ints: Sequence[int]):
    for i in ints:
        print(i)

a = process_sequence("spam")
print_str(a)

a = process_sequence([1,2,3,42])
print_sequence_of_ints(a)

However, when I try to narrow down _S:

_S = Sequence

or

_S = TypeVar('_S', Sequence, str)

mypy (a static type checker for Python) reports the following error:

error: Argument 1 to "print_str" has incompatible type "Sequence[Any]"; expected "str"

How can I add a type hint to my function that says the input must be a sequence, and the output has the same type as the input, and make mypy happy at the same time?

Upvotes: 2

Views: 2329

Answers (2)

MacFreek

Reputation: 3446

I have found a solution:

from typing import TypeVar, Sequence

_S = TypeVar('_S', str, Sequence)

def process_sequence(s: _S) -> _S:
    return s

def print_str(word: str):
    print(word)

def print_sequence_of_ints(ints: Sequence[int]):
    for i in ints:
        print(i)

a = process_sequence("spam")
print_str(a)

b = process_sequence([1,2,3,42])
print_sequence_of_ints(b)

In this case, mypy is happy. Apparently, in the TypeVar declaration I have to list the more specific str before the more generic Sequence (_S = TypeVar('_S', Sequence, str) still gives an error).
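For illustration, here is a minimal sketch of a function that actually transforms its input instead of returning it unchanged, using the str-first constrained TypeVar from this answer. The name reverse_sequence and the slicing step are hypothetical stand-ins for real processing. This works because mypy checks the body of a function with a constrained TypeVar once per constraint, and slicing is valid for both str and Sequence.

```python
from typing import Sequence, TypeVar

# Constrained TypeVar, with the more specific str listed first as in the answer
_S = TypeVar('_S', str, Sequence)

def reverse_sequence(s: _S) -> _S:
    # Hypothetical processing step: slicing a str yields a str and
    # slicing a Sequence yields a Sequence, so both constraints type-check
    return s[::-1]

word = reverse_sequence("spam")        # mypy should infer str here
numbers = reverse_sequence([1, 2, 3])  # mypy should infer Sequence[Any] here
print(word)
print(list(numbers))
```

Note that for the list argument the result is typed as the constraint Sequence (that is, Sequence[Any]) rather than list[int]; a constrained TypeVar maps the argument to one of its listed types.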

I also tried to educate mypy using a # type: str comment, but that did not work.

Upvotes: 3

TwistedSim

Reputation: 2030

a is of type Sequence[Any], not str. If you are sure that a will always be a string, you can cast it with print_str(cast(str, a)).

from typing import Any, Sequence, cast

_S = Sequence[Any]

def process_sequence(s: _S) -> _S:
    return s

def print_str(word: str):
    print(word)

def print_sequence_of_ints(ints: Sequence[int]):
    for i in ints:
        print(i)

a = process_sequence("spam")
# a is of type Sequence[Any], not str, so tell mypy to treat it as a str
print_str(cast(str, a))

a = process_sequence([1, 2, 3, 42])
print_sequence_of_ints(a)
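Keep in mind that typing.cast only affects the type checker: at run time it returns its second argument unchanged and performs no conversion or check, so a wrong cast is not caught. A small sketch showing this (the list value is arbitrary):

```python
from typing import cast

value = [1, 2, 3]
# cast() neither converts nor validates; it hands back the very same object
same = cast(str, value)
print(same is value)
print(type(same).__name__)
```

This is why the cast is only appropriate when you really know the runtime type.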

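You can also use T = TypeVar('T') instead of Sequence[Any], but then mypy no longer knows inside the function that s is a sequence, and callers can pass any type at all, so you lose some typing information and protection.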
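A sketch of the TypeVar variant. The bound=Sequence[Any] argument is an addition that is not in the answer: with a bound (rather than constrained) TypeVar, mypy should infer the concrete argument type (str for "spam", list[int] for the list) as the return type, while still treating s as a Sequence inside the function and rejecting non-sequence arguments. Treat it as an alternative to try, not as the answerer's method.

```python
from typing import Any, Sequence, TypeVar

# Bound TypeVar: any subtype of Sequence is accepted,
# and the return type is the exact type of the argument
T = TypeVar('T', bound=Sequence[Any])

def process_sequence(s: T) -> T:
    return s

def print_str(word: str) -> None:
    print(word)

def print_sequence_of_ints(ints: Sequence[int]) -> None:
    for i in ints:
        print(i)

a = process_sequence("spam")         # T is inferred as str
print_str(a)

b = process_sequence([1, 2, 3, 42])  # T is inferred as list[int]
print_sequence_of_ints(b)

# process_sequence(42) would be rejected by mypy: int is not a Sequence
```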

Upvotes: 1
