MrBean Bremen

Reputation: 16815

How to correctly type-annotate functions with variable types?

I'm trying to add type hints to a file-system-related library in which many functions take a path that is either str or bytes. I can handle my own functions with overloads, but I'm struggling with simple operations and standard library functions that are called inside those functions with arguments of either type. Here is a simplified example:

from typing import Union, overload


@overload
def join_paths(s1: str, s2: str) -> str: ...


@overload
def join_paths(s1: bytes, s2: bytes) -> bytes: ...


def join_paths(s1: Union[str, bytes],
               s2: Union[str, bytes]) -> Union[str, bytes]:
    return s1 + s2

The overloads work fine when I call this function from elsewhere, but my problem is the s1 + s2 expression in the body, which makes mypy report these errors:

example.py:74: error: Unsupported operand types for + ("str" and "bytes")  [operator]
example.py:74: error: Unsupported operand types for + ("bytes" and "str")  [operator]

What I want to express is that either both operands are str or both are bytes, just as the overloads express for callers of my own function.

I don't have much experience with typing, so I may just be missing the obvious solution, but so far I haven't found a way to avoid these errors.

Upvotes: 5

Views: 852

Answers (2)

Paul Lemarchand

Reputation: 2096

typing.AnyStr is the best fit for this specific case.

From the documentation:

It is meant to be used for functions that may accept any kind of string without allowing different kinds of strings to mix. For example:

def concat(a: AnyStr, b: AnyStr) -> AnyStr:
    return a + b

concat(u"foo", u"bar")  # Ok, output has type 'unicode'
concat(b"foo", b"bar")  # Ok, output has type 'bytes'
concat(u"foo", b"bar")  # Error, cannot mix unicode and bytes

So you can rewrite your code like this:

from typing import AnyStr


def join_paths(s1: AnyStr, s2: AnyStr) -> AnyStr:
    return s1 + s2

join_paths("s1", "s2")  # OK
join_paths(b"s1", b"s2")  # OK
join_paths("s1", b"s2")  # error: Value of type variable "AnyStr" of "join_paths" cannot be "object"
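The question also mentions standard library functions that are called with either type. As a rough sketch, the same annotation should carry over to those calls. Note that using os.path.join here is my own choice of illustration and not taken from the question's code, and the variable names are hypothetical. It relies on the os.path functions accepting either all-str or all-bytes arguments:

```python
import os.path
from typing import AnyStr


def join_paths(s1: AnyStr, s2: AnyStr) -> AnyStr:
    # Assumption: os.path.join stands in for whatever stdlib call the
    # real library makes. A type checker such as mypy checks this body
    # once with AnyStr = str and once with AnyStr = bytes, so s1 and s2
    # are never treated as a mix of the two.
    return os.path.join(s1, s2)


text_path = join_paths("dir", "file.txt")     # inferred as str
bytes_path = join_paths(b"dir", b"file.txt")  # inferred as bytes
```

Since the body is checked separately for each allowed type, no isinstance checks or casts are needed inside the function.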

Upvotes: 2

Samwise

Reputation: 71454

Use a TypeVar:

from typing import TypeVar

T = TypeVar('T', str, bytes)


def join_paths(s1: T, s2: T) -> T:
    return s1 + s2


join_paths("foo", "bar")    # fine
join_paths(b"foo", b"bar")  # fine
join_paths(1, 2)            # error: T can't be int
join_paths("foo", b"bar")   # error: T can't be object

Overloading is more of a tool of last resort, for cases where you can't express the type relationships via TypeVars and generics. Using overloads effectively usually involves a lot of runtime type assertions (or # type: ignore comments) in the body of a loosely typed implementation.
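To illustrate what those runtime checks tend to look like, here is a minimal sketch of the question's overloaded version with isinstance narrowing in the body. The TypeError and its message are my own addition for illustration, not something the original code does:

```python
from typing import Union, overload


@overload
def join_paths(s1: str, s2: str) -> str: ...


@overload
def join_paths(s1: bytes, s2: bytes) -> bytes: ...


def join_paths(s1: Union[str, bytes],
               s2: Union[str, bytes]) -> Union[str, bytes]:
    # Each branch narrows both arguments to one concrete type,
    # so the + in that branch never sees a str/bytes mix.
    if isinstance(s1, str) and isinstance(s2, str):
        return s1 + s2
    if isinstance(s1, bytes) and isinstance(s2, bytes):
        return s1 + s2
    # Hypothetical fallback: callers that respect the overloads
    # never reach this line.
    raise TypeError("s1 and s2 must both be str or both be bytes")
```

This removes the operator errors, but at the cost of duplicated branches that the TypeVar version does not need.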

Upvotes: 8
