Aydin Abiar

How can I get pyright to infer the ndarray type after indexing?

I have recently started using pyright in strict mode to type-check my code. It works well except for one specific edge case involving NumPy.

>> pip list | grep numpy
numpy                     2.1.0
>> pip list | grep pyright
pyright                   1.1.377

Consider a function like this:

import numpy as np

def str_to_int_array(
    foo: str
) -> np.ndarray[tuple[int, int], np.dtype[np.int_]]:
    # batch_str_to_int_array is defined elsewhere and returns
    # np.ndarray[tuple[int, int, int], np.dtype[np.int_]]
    int_array_3d = batch_str_to_int_array([foo])
    # SHOULD be np.ndarray[tuple[int, int], np.dtype[np.int_]], but pyright infers Any
    int_array_2d = int_array_3d[0]
    return int_array_2d

The code runs correctly and pyright still passes, but it does not recognize the returned value as np.ndarray[tuple[int, int], np.dtype[np.int_]]. Instead, it infers the indexed array as Any.

I expected pyright to infer the type of the indexed array as np.ndarray[tuple[int, int], np.dtype[np.int_]].

Why is that not the case? I understand there are edge cases with ndarray indexing (the result can be an array or a plain numerical value, for example), but in this specific case of calling __getitem__ with a single integer, shouldn't the type be obvious, i.e. one dimension removed from the shape and the same element dtype kept?
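The edge case mentioned above can be checked at runtime: a single integer index removes one leading axis, but once the array is 1-D the result is a NumPy scalar rather than an ndarray. My guess (not confirmed) is that because the stubs' __getitem__ overloads don't track the number of dimensions, the integer-index case has to fall back to Any. The describe helper below is just a small illustration, not part of NumPy.

```python
import numpy as np


def describe(value: object) -> str:
    """Hypothetical helper: report what an indexing result is at runtime."""
    if isinstance(value, np.ndarray):
        return f"ndarray, shape {value.shape}"
    if isinstance(value, np.generic):
        return "NumPy scalar"
    return "other"


a3 = np.zeros((2, 3, 4), dtype=np.int_)
print(describe(a3[0]))        # ndarray, shape (3, 4)
print(describe(a3[0][0]))     # ndarray, shape (4,)
print(describe(a3[0][0][0]))  # NumPy scalar
```

The same expression, "index with an int", gives an array for ndim > 1 and a scalar for ndim == 1, which a shape-unaware overload cannot distinguish.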

Am I doing this the wrong way?
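For reference, one stopgap is to restate the known type with typing.cast, which pyright trusts and which does nothing at runtime. The sketch below is self-contained under stated assumptions: batch_str_to_int_array is a hypothetical stand-in for the real helper (it just stores each character's code point in a (batch, length, 1) array and assumes equal-length strings), and the IntArray2D/IntArray3D aliases are names introduced here for readability.

```python
from typing import cast

import numpy as np

# Hypothetical aliases for the shapes used in the question.
IntArray2D = np.ndarray[tuple[int, int], np.dtype[np.int_]]
IntArray3D = np.ndarray[tuple[int, int, int], np.dtype[np.int_]]


def batch_str_to_int_array(batch: list[str]) -> IntArray3D:
    # Hypothetical stand-in for the real helper: shape (batch, len(s), 1),
    # holding each character's Unicode code point. Assumes equal-length strings.
    codes = [[[ord(ch)] for ch in s] for s in batch]
    # np.array's result is not typed with a fixed-length shape, so cast here too.
    return cast(IntArray3D, np.array(codes, dtype=np.int_))


def str_to_int_array(foo: str) -> IntArray2D:
    int_array_3d = batch_str_to_int_array([foo])
    # Integer indexing comes back as Any from the stubs; cast() restates the
    # type known to hold at runtime. It performs no check or conversion.
    return cast(IntArray2D, int_array_3d[0])


print(str_to_int_array("abc").shape)  # (3, 1)
```

The cast only silences the checker; if the helper ever returned a different shape, nothing would catch it, so this is a workaround rather than real inference.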

Thanks a lot.