Reputation: 99
I am looking for an elegant way to check if a given index is inside a numpy array (for example for BFS algorithms on a grid).
The following code does what I want:
```python
import numpy as np

def isValid(np_shape: tuple, index: tuple):
    if min(index) < 0:
        return False
    for ind, sh in zip(index, np_shape):
        if ind >= sh:
            return False
    return True

arr = np.zeros((3, 5))
print(isValid(arr.shape, (0, 0)))  # True
print(isValid(arr.shape, (2, 4)))  # True
print(isValid(arr.shape, (4, 4)))  # False
```
But I'd prefer something built-in, or at least more elegant than writing my own function with Python for-loops (yikes).
Upvotes: 8
Views: 1051
Reputation: 763
For low-dimensional problems, I would simply use
```python
def isValid(shape: tuple, index: tuple):
    # Written out for three dimensions; add or drop a line to match the array's ndim.
    return (0 <= index[0] < shape[0] and
            0 <= index[1] < shape[1] and
            0 <= index[2] < shape[2])
```
It's elegant and self-explanatory. Performance-wise, this approach also outperforms the alternatives: it is about 3-4 times faster than the OP's solution.
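Since the question mentions BFS on a grid, here is a minimal sketch of how the chained-comparison check slots into a breadth-first search. The names `is_valid_2d` and `bfs_distances` are hypothetical, and treating cells equal to 0 as free (anything else as a wall) is an assumption made only for this illustration.

```python
from collections import deque

import numpy as np


def is_valid_2d(shape: tuple, index: tuple) -> bool:
    # Chained comparisons written out for exactly two dimensions.
    return 0 <= index[0] < shape[0] and 0 <= index[1] < shape[1]


def bfs_distances(grid: np.ndarray, start: tuple) -> dict:
    """Return the BFS step count from `start` to every reachable cell.

    Assumption for this sketch: cells equal to 0 are free, others are walls.
    """
    dist = {start: 0}
    queue = deque([start])
    while queue:
        r, c = queue.popleft()
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nxt = (r + dr, c + dc)
            # Bounds check first, so grid[nxt] is never evaluated out of range
            # and a negative index never silently wraps to the other edge.
            if is_valid_2d(grid.shape, nxt) and grid[nxt] == 0 and nxt not in dist:
                dist[nxt] = dist[(r, c)] + 1
                queue.append(nxt)
    return dist


arr = np.zeros((3, 5))
d = bfs_distances(arr, (0, 0))
print(len(d))      # 15: every cell of the 3x5 grid is reachable
print(d[(2, 4)])   # 6: Manhattan distance on an obstacle-free grid
```

The negative-index case is the main reason the explicit check matters here: NumPy would happily accept `grid[(0, -1)]` and return the last column instead of raising.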
For high-dimensional problems, I would generally go for
```python
def isValid(shape: tuple, index: tuple):
    for i in range(len(shape)):
        if not (0 <= index[i] < shape[i]):
            return False
    return True
```
It's more elegant for high dimensions than the previous solution and only a few nanoseconds slower.
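A comparable timing can be sketched with the standard-library `timeit` module. The function names `is_valid_loop` and `is_valid_numpy` and the set of test cases are assumptions for illustration; absolute numbers will vary by machine and NumPy version, so none are quoted here.

```python
import timeit

import numpy as np


def is_valid_loop(shape: tuple, index: tuple) -> bool:
    # Explicit loop: stops at the first out-of-range coordinate.
    for i in range(len(shape)):
        if not (0 <= index[i] < shape[i]):
            return False
    return True


def is_valid_numpy(shape: tuple, index: tuple) -> bool:
    # Vectorized check: converts the index to an array on every call.
    index = np.asarray(index)
    return bool((index >= 0).all() and (index < shape).all())


shape = (3, 5)
cases = [(0, 0), (2, 4), (4, 4), (-1, 0)]
rounds = 10_000

# Both versions must agree before their timings are worth comparing.
assert all(is_valid_loop(shape, c) == is_valid_numpy(shape, c) for c in cases)

for name, fn in (("loop", is_valid_loop), ("numpy", is_valid_numpy)):
    seconds = timeit.timeit(lambda: [fn(shape, c) for c in cases], number=rounds)
    print(f"{name}: {seconds:.3f} s for {rounds} rounds")
```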
Here is my benchmark (Method 0 is the OP's solution, Method 1 is my solution):
Upvotes: 0
Reputation: 99
I have benchmarked the answers quite a bit and have come to the conclusion that the explicit for-loop from my question actually performs best.
Dmitri's solution is wrong for several reasons: tuple1 < tuple2 compares lexicographically, so only the first differing element decides the result; and ideas like np.all(ind < sh for ind, sh in zip(index, np_shape)) fail because np.all receives a generator, not a list, and does not iterate over it.
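Both pitfalls can be seen in a few lines. The fix for the generator problem is to use Python's built-in `all()`, which does consume a generator lazily and stops at the first `False`. The function name `is_valid` and the added length check are choices made for this sketch, not part of any answer above.

```python
def is_valid(shape: tuple, index: tuple) -> bool:
    # Built-in all() iterates the generator and short-circuits on the first
    # False, unlike np.all, which treats a generator as a single object.
    return len(index) == len(shape) and all(
        0 <= i < s for i, s in zip(index, shape)
    )


shape = (3, 5)

# Tuple comparison is lexicographic: the first differing element decides,
# so an out-of-range second coordinate goes unnoticed here.
print((2, 9) < shape)           # True, although 9 >= 5

print(is_valid(shape, (0, 0)))  # True
print(is_valid(shape, (2, 4)))  # True
print(is_valid(shape, (4, 4)))  # False
print(is_valid(shape, (2, 9)))  # False
```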
@mozway's solution is correct, but converting the index to a NumPy array on every call makes it a lot slower. It also always has to convert every coordinate, whereas an explicit loop can stop at the first out-of-range one, I suppose.
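The conversion cost is paid per call, so it weighs heavily when checking one index at a time. If many candidate indices are checked together (for example all neighbours of a BFS frontier), one conversion can cover the whole batch. The sketch below is an assumption-based extension, not something benchmarked in this thread; `valid_mask` is a hypothetical name and `indices` is assumed to be an `(n, ndim)` array-like of integers.

```python
import numpy as np


def valid_mask(shape: tuple, indices) -> np.ndarray:
    """Return a boolean mask with one entry per row of `indices`."""
    idx = np.asarray(indices)
    bounds = np.asarray(shape)
    # One conversion and one vectorized comparison for the whole batch;
    # each row is valid only if every coordinate is in range.
    return np.all((idx >= 0) & (idx < bounds), axis=1)


candidates = [(0, 0), (2, 4), (4, 4), (-1, 0)]
mask = valid_mask((3, 5), candidates)
print(mask)  # [ True  True False False]

# Keep only the in-bounds candidates.
in_bounds = np.asarray(candidates)[mask]
print(in_bounds)
```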
Here is my benchmark (Method 0 is @mozway's solution, Method 1 is my solution):
Upvotes: 2
Reputation: 261860
You can try:
```python
import numpy as np

def isValid(np_shape: tuple, index: tuple):
    index = np.array(index)
    # Compare against the np_shape argument, not a global array's shape.
    return (index >= 0).all() and (index < np_shape).all()

arr = np.zeros((3, 5))
print(isValid(arr.shape, (0, 0)))  # True
print(isValid(arr.shape, (2, 4)))  # True
print(isValid(arr.shape, (4, 4)))  # False
```
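One caveat with the array-based version: NumPy broadcasting means an index of the wrong length may not be rejected. For example, `index=(1,)` against shape `(3, 5)` compares the single value with both dimensions and returns True. A minimal sketch with an explicit length guard follows; the name `is_valid_checked` is hypothetical.

```python
import numpy as np


def is_valid_checked(np_shape: tuple, index: tuple) -> bool:
    index = np.asarray(index)
    # Without this guard, broadcasting would accept e.g. index=(1,) against
    # shape (3, 5), or raise for an index with too many entries.
    if index.shape != (len(np_shape),):
        return False
    return bool((index >= 0).all() and (index < np_shape).all())


arr = np.zeros((3, 5))
print(is_valid_checked(arr.shape, (2, 4)))     # True
print(is_valid_checked(arr.shape, (4, 4)))     # False
print(is_valid_checked(arr.shape, (1,)))       # False: wrong number of coordinates
print(is_valid_checked(arr.shape, (1, 2, 3)))  # False: wrong number of coordinates
```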
Upvotes: 2