I have recently started using pyright in strict mode to type-check my code. It works perfectly except for one specific edge case involving numpy.
>> pip list | grep numpy
numpy 2.1.0
>> pip list | grep pyright
pyright 1.1.377
Consider a function like this:
import numpy as np

def str_to_int_array(
    foo: str
) -> np.ndarray[tuple[int, int], np.dtype[np.int_]]:
    # batch_str_to_int_array (defined elsewhere) returns np.ndarray[tuple[int, int, int], np.dtype[np.int_]]
    int_array_3d = batch_str_to_int_array([foo])
    # SHOULD be np.ndarray[tuple[int, int], np.dtype[np.int_]], but pyright infers Any
    int_array_2d = int_array_3d[0]
    return int_array_2d
Everything runs fine, but pyright does not recognize the return type as np.ndarray[tuple[int, int], np.dtype[np.int_]]. Instead, it reports the indexed 2D array as having type Any.
My understanding is that pyright should be able to infer the type of the indexed array as np.ndarray[tuple[int, int], np.dtype[np.int_]].
Why is that not the case? I understand there are edge cases with ndarray indexing (some indices return arrays, others return scalars, etc.), but in this specific case of calling __getitem__ with a single integer, shouldn't the type be obvious, i.e. remove one dimension from the shape and keep the same dtype?
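The runtime behavior described above is easy to confirm. This small check uses an arbitrary arange-based array as a stand-in for the batch output. It shows that a single integer index drops the leading axis and keeps the dtype, while indexing every axis returns a NumPy scalar rather than an ndarray:

```python
import numpy as np

# A small 3D integer array standing in for the batch output.
arr_3d = np.arange(24, dtype=np.int_).reshape(2, 3, 4)

# A single integer index removes the leading axis and keeps the dtype.
arr_2d = arr_3d[0]
print(arr_2d.shape, arr_2d.dtype == arr_3d.dtype)  # (3, 4) True

# With one integer per axis the result is a NumPy scalar, not an ndarray:
# this is the array-vs-scalar edge case mentioned in the question.
item = arr_3d[0, 0, 0]
print(isinstance(item, np.ndarray), isinstance(item, np.integer))  # False True
```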
Am I doing this the wrong way?
Thanks a lot.