hm8

Reputation: 1513

Python: Find largest array index along a specific dimension which is greater than a threshold

Let's say I have a 4-D NumPy array (e.g. np.random.rand(x, y, z, t)) of data with dimensions corresponding to X, Y, Z, and time.

For each X and Y point, and at each time step, I want to find the largest index in Z for which the data is larger than some threshold n.

So my end result should be an X-by-Y-by-t array. Instances where there are no values in the Z-column greater than the threshold should be represented by a 0.

I can loop through element by element and construct a new array as I go; however, I am operating on a very large array, and it takes too long.
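For reference, the element-by-element loop described above might look like the sketch below. The function name last_index_above_loop and its invalid parameter are hypothetical, not from the question. It is far too slow for large arrays, but it can serve as a correctness check for vectorized versions. Note that with invalid=0, a column whose only hit is at Z index 0 cannot be told apart from a column with no hit at all; passing something like invalid=-1 avoids that ambiguity.

```python
import numpy as np


def last_index_above_loop(data, threshold, invalid=0):
    """Slow reference version using explicit loops over X, Y and t.

    For every (x, y, t) it scans the Z column from the end and records the
    largest Z index whose value exceeds `threshold`; columns with no such
    value get `invalid` (0 by default, as in the question).
    """
    nx, ny, nz, nt = data.shape
    out = np.full((nx, ny, nt), invalid, dtype=np.intp)
    for x in range(nx):
        for y in range(ny):
            for t in range(nt):
                # walk Z backwards so the first hit is the largest index
                for z in range(nz - 1, -1, -1):
                    if data[x, y, z, t] > threshold:
                        out[x, y, t] = z
                        break
    return out


# Hand-made data of shape (X=1, Y=1, Z=4, t=2)
demo = np.zeros((1, 1, 4, 2))
demo[0, 0, [1, 3], 0] = 1.0  # t=0: hits at Z = 1 and 3
                             # t=1: no hits at all
print(last_index_above_loop(demo, 0.5)[0, 0].tolist())  # [3, 0]

rng = np.random.default_rng(0)
data = rng.random((4, 3, 5, 2))
result = last_index_above_loop(data, 0.75)
print(result.shape)  # (4, 3, 2)
```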

Upvotes: 2

Views: 144

Answers (2)

Divakar

Reputation: 221564

Here's a faster approach (it works along axis 2, the Z axis):

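def faster(a, n, invalid_specifier):
    mask = a > n
    # argmax on the Z-reversed mask finds the first True from the end,
    # i.e. the last Z index (axis 2) where a > n
    idx = a.shape[2] - mask[:, :, ::-1].argmax(2) - 1
    # all-False columns also give argmax 0 -> idx == Z-1; keep Z-1 only if
    # the last element is a genuine hit, otherwise mark it as invalid
    idx[~mask[:, :, -1] & (idx == a.shape[2] - 1)] = invalid_specifier
    return idx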
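To see why the masking line in faster is needed, here is a small hand-built check. The array and its expected values are made up for illustration, and invalid_specifier=-1 is used instead of 0 so the "no hit" case stands out. For a column with no value above n, argmax on the all-False reversed column returns 0, which maps to the last Z index; checking mask[:, :, -1] separates that case from a genuine hit at the last index.

```python
import numpy as np


def faster(a, n, invalid_specifier):
    # same logic as the answer above, repeated so this snippet runs on its own
    mask = a > n
    idx = a.shape[2] - mask[:, :, ::-1].argmax(2) - 1
    idx[~mask[:, :, -1] & (idx == a.shape[2] - 1)] = invalid_specifier
    return idx


# Shape (X=1, Y=3, Z=4, t=1): three hand-made Z columns
a = np.zeros((1, 3, 4, 1))
a[0, 0, [0, 2], 0] = 1.0  # hits at Z = 0 and 2 -> expect 2
a[0, 1, 3, 0] = 1.0       # only hit is the last Z index -> expect 3
                          # column 2 has no hit -> expect invalid_specifier

print(faster(a, 0.5, invalid_specifier=-1)[0, :, 0].tolist())  # [2, 3, -1]
```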

Runtime test:

# Using @DSM's benchmarking setup

In [553]: a = np.random.random((100,100,30,50))
     ...: n = 0.75
     ...: 

In [554]: out1 = faster(a,n,invalid_specifier=0)
     ...: out2 = fast(a, axis=2, threshold=n) # @DSM's soln
     ...: 

In [555]: np.allclose(out1,out2)
Out[555]: True

In [556]: %timeit fast(a, axis=2, threshold=n)  # @DSM's soln
10 loops, best of 3: 64.6 ms per loop

In [557]: %timeit faster(a,n,invalid_specifier=0)
10 loops, best of 3: 43.7 ms per loop

Upvotes: 2

DSM

Reputation: 353059

Unfortunately, following the example of Python builtins, NumPy doesn't make it easy to get the last index, although the first is trivial with argmax. Still, something like

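import numpy as np

def slow(arr, axis, threshold):
    # cumsum is non-decreasing, so argmax (first occurrence of the maximum)
    # lands on the last element above threshold; all-False columns give 0
    return (arr > threshold).cumsum(axis=axis).argmax(axis=axis)

def fast(arr, axis, threshold):
    compare = (arr > threshold)
    # moveaxis (unlike swapaxes) keeps the other axes in their original order,
    # so the result lines up with compare.any(axis=axis) for any axis
    reordered = np.moveaxis(compare, axis, -1)
    flipped = reordered[..., ::-1]
    first_above = flipped.argmax(axis=-1)
    last_above = flipped.shape[-1] - first_above - 1
    are_any_above = compare.any(axis=axis)
    # use 0 where no element along the axis exceeds the threshold
    patched = np.where(are_any_above, last_above, 0)
    return patched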

gives me

In [14]: arr = np.random.random((100,100,30,50))

In [15]: %timeit a = slow(arr, axis=2, threshold=0.75)
1 loop, best of 3: 248 ms per loop

In [16]: %timeit b = fast(arr, axis=2, threshold=0.75)
10 loops, best of 3: 50.9 ms per loop

In [17]: (slow(arr, axis=2, threshold=0.75) == fast(arr, axis=2, threshold=0.75)).all()
Out[17]: True

(There's probably a slicker way to do the flipping, but it's the end of the day here and my brain is shutting down. :-)
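On the "slicker way to do the flipping": one option, assuming NumPy 1.12 or later (where np.flip accepts an axis argument), is to flip along the target axis directly, with no axis reordering at all. The sketch below uses hypothetical names (fast_flip and its invalid parameter are not from either answer). It should agree with slow for any axis, but it has not been benchmarked against the versions above.

```python
import numpy as np


def fast_flip(arr, axis, threshold, invalid=0):
    """Last index along `axis` where arr > threshold, `invalid` where none.

    Hypothetical variant: flips the mask along `axis` in place, so the
    remaining axes of the output stay in their original order.
    """
    compare = arr > threshold
    # argmax of the flipped mask = distance of the last True from the end
    first_from_end = np.flip(compare, axis=axis).argmax(axis=axis)
    last_above = compare.shape[axis] - first_from_end - 1
    # argmax also returns 0 for all-False slices, so patch those explicitly
    return np.where(compare.any(axis=axis), last_above, invalid)


rng = np.random.default_rng(1)
arr = rng.random((6, 5, 4, 3))
out = fast_flip(arr, axis=2, threshold=0.75)
print(out.shape)  # (6, 5, 3)
```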

Upvotes: 3
