schajan

Reputation: 99

Is there an elegant way to check if an index can be requested in a numpy array?

I am looking for an elegant way to check if a given index is inside a numpy array (for example for BFS algorithms on a grid).

The following code does what I want:

import numpy as np

def isValid(np_shape: tuple, index: tuple):
    if min(index) < 0:
        return False
    for ind, sh in zip(index, np_shape):
        if ind >= sh:
            return False
    return True

arr = np.zeros((3, 5))
print(isValid(arr.shape, (0, 0)))  # True
print(isValid(arr.shape, (2, 4)))  # True
print(isValid(arr.shape, (4, 4)))  # False

But I'd prefer something built-in, or at least more elegant than writing my own function with Python for-loops (yikes).
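A tempting "built-in" shortcut is to simply try the indexing and catch `IndexError`. The sketch below (the helper `is_valid_by_try` is hypothetical, not from any answer) shows why that is not equivalent to the check above: NumPy accepts negative indices, which count from the end, so an index like `(-1, -1)` does not raise even though it lies outside the 0-based range a grid BFS needs.

```python
import numpy as np

def is_valid_by_try(arr: np.ndarray, index: tuple) -> bool:
    """Hypothetical shortcut: treat 'indexing succeeds' as 'index is valid'."""
    try:
        arr[index]
        return True
    except IndexError:
        return False

arr = np.zeros((3, 5))
print(is_valid_by_try(arr, (2, 4)))    # True
print(is_valid_by_try(arr, (4, 4)))    # False: row 4 is out of bounds
print(is_valid_by_try(arr, (-1, -1)))  # True: negative indices wrap around to (2, 4)
```

So an explicit bounds check is still needed whenever negative coordinates must count as invalid.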

Upvotes: 8

Views: 1051

Answers (3)

Ricoter

Reputation: 763

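For low-dimensional problems, I would simply write out one chained comparison per axis (shown here for a 3D array):

def isValid(shape: tuple, index: tuple):
    # Written out for a 3D array: use exactly one comparison per dimension
    return (0 <= index[0] < shape[0] and
            0 <= index[1] < shape[1] and
            0 <= index[2] < shape[2])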

It's elegant and self-explanatory. Performance-wise, it also outperforms the alternatives: it is about 3-4 times faster than the OP's solution.
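For the question's 2D example (`np.zeros((3, 5))`), the same chained-comparison idea needs only two terms. A minimal sketch, with the hypothetical name `is_valid_2d`:

```python
def is_valid_2d(shape: tuple, index: tuple) -> bool:
    # One chained comparison per axis; this version assumes a 2D shape and index.
    return 0 <= index[0] < shape[0] and 0 <= index[1] < shape[1]

shape = (3, 5)  # same as np.zeros((3, 5)).shape
print(is_valid_2d(shape, (0, 0)))  # True
print(is_valid_2d(shape, (2, 4)))  # True
print(is_valid_2d(shape, (4, 4)))  # False
```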

For high-dimensional problems, I would generally go for

def isValid(shape: tuple, index: tuple):
    # Check each coordinate against its axis length, stopping at the first bad one
    for i in range(len(shape)):
        if not (0 <= index[i] < shape[i]):
            return False
    return True

It's more elegant for high dimensions than the previous solution and only a few nanoseconds slower.
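The same loop can be written as a one-liner with the built-in `all()` over a generator, which also stops at the first out-of-range coordinate. This is a sketch; the name `is_valid_nd` is hypothetical, and the extra length check (rejecting an index with the wrong number of coordinates) is an added assumption, not part of the answer above.

```python
def is_valid_nd(shape: tuple, index: tuple) -> bool:
    # Require exactly one coordinate per axis (an added assumption), then
    # check each coordinate; all() short-circuits on the first False.
    return len(index) == len(shape) and all(
        0 <= i < s for i, s in zip(index, shape)
    )

print(is_valid_nd((3, 5), (2, 4)))  # True
print(is_valid_nd((3, 5), (4, 4)))  # False
```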

Here is my benchmark (Method 0 is the OP's solution, Method 1 is my solution):

[Benchmark plot: Method 0 (OP's solution) vs. Method 1 (this answer)]

Upvotes: 0

schajan

Reputation: 99

I have benchmarked the answers quite a bit and have come to the conclusion that the explicit for loop in my code actually performs best.

Dmitri's solution is wrong for several reasons: tuple1 < tuple2 compares the tuples lexicographically, so only the first differing element decides the result; and ideas like np.all(ind < sh for ind, sh in zip(index, np_shape)) fail because np.all receives a generator, not a list or array.
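A quick illustration of the tuple-comparison pitfall: Python compares tuples lexicographically, so a coordinate that is out of range on a later axis can still make `index < shape` evaluate to `True`. The element-wise version with the built-in `all()` catches it.

```python
shape = (3, 5)

# Lexicographic comparison: the first position where the elements differ
# decides the result, and later positions are ignored.
print((0, 10) < shape)  # True, although 10 is out of range for axis 1
print((2, 4) < shape)   # True

# Element-wise check: every coordinate must be below its axis length.
elementwise_ok = all(i < s for i, s in zip((0, 10), shape))
print(elementwise_ok)   # False
```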

@mozway's solution is correct, but converting the index to a NumPy array makes it a lot slower. It also always has to convert every element, whereas an explicit loop can return early, I suppose.

Here is my benchmark (Method 0 is @mozway's solution, Method 1 is my solution):

[Benchmark plot: Method 0 (@mozway's solution) vs. Method 1 (my solution)]
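A minimal sketch for reproducing this kind of comparison with the standard `timeit` module. The function names are hypothetical renamings of the two approaches (the array version here compares against the passed-in shape), and the timings it prints will depend on the machine and NumPy version.

```python
import timeit

import numpy as np

def is_valid_loop(np_shape: tuple, index: tuple) -> bool:
    # Explicit loop from the question: returns early on the first bad coordinate.
    if min(index) < 0:
        return False
    for ind, sh in zip(index, np_shape):
        if ind >= sh:
            return False
    return True

def is_valid_numpy(np_shape: tuple, index: tuple) -> bool:
    # Array-based check: converts the whole index to a NumPy array first.
    index = np.array(index)
    return bool((index >= 0).all() and (index < np_shape).all())

shape = (3, 5)
for func in (is_valid_loop, is_valid_numpy):
    # Time 10,000 calls on an in-range index.
    seconds = timeit.timeit(lambda: func(shape, (2, 4)), number=10_000)
    print(f"{func.__name__}: {seconds:.4f} s")
```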

Upvotes: 2

mozway

Reputation: 261860

You can try:

import numpy as np

def isValid(np_shape: tuple, index: tuple):
    index = np.array(index)
    # Compare against the shape passed in, not a global array
    return (index >= 0).all() and (index < np_shape).all()

arr = np.zeros((3, 5))
print(isValid(arr.shape, (0, 0)))  # True
print(isValid(arr.shape, (2, 4)))  # True
print(isValid(arr.shape, (4, 4)))  # False
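Where the array approach pays off is checking many candidate indices in one call, e.g. all neighbors of a cell in a grid BFS. A sketch under that assumption (the helper `valid_mask` and the 4-neighbor offsets are hypothetical, not part of the answer): with candidates stacked as rows of an `(n, ndim)` array, broadcasting compares every row against the shape at once.

```python
import numpy as np

def valid_mask(np_shape: tuple, indices: np.ndarray) -> np.ndarray:
    # indices has shape (n, ndim): one candidate index per row.
    # Broadcasting compares every row against the shape at once.
    return np.all((indices >= 0) & (indices < np.array(np_shape)), axis=1)

arr = np.zeros((3, 5))
current = np.array([0, 4])
# The four axis-aligned neighbor offsets of a 2D grid cell.
offsets = np.array([[-1, 0], [1, 0], [0, -1], [0, 1]])
neighbors = current + offsets
print(neighbors[valid_mask(arr.shape, neighbors)])  # keeps (1, 4) and (0, 3)
```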

Upvotes: 2
