Let's say I have a 4-D NumPy array of data (e.g. np.random.rand(x, y, z, t)) with dimensions corresponding to X, Y, Z, and time.
For each X and Y point, and at each time step, I want to find the largest index in Z for which the data is larger than some threshold n.
So my end result should be an X-by-Y-by-t array. Instances where there are no values in the Z-column greater than the threshold should be represented by a 0.
I can loop through element by element and construct a new array as I go; however, I am operating on a very large array and it takes too long.
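For reference, the element-by-element approach described above might look like the sketch below. The name last_above_loop is hypothetical, and the function is only meant to pin down the expected output (largest Z index above n, else 0), not to be fast.

```python
import numpy as np

def last_above_loop(a, n):
    """Hypothetical reference loop: for each (x, y, t), return the largest z
    with a[x, y, z, t] > n, or 0 if no element in that Z-column exceeds n."""
    X, Y, Z, T = a.shape
    out = np.zeros((X, Y, T), dtype=int)
    for x in range(X):
        for y in range(Y):
            for t in range(T):
                # Walk the Z-column from the top down and stop at the first hit
                for z in range(Z - 1, -1, -1):
                    if a[x, y, z, t] > n:
                        out[x, y, t] = z
                        break
    return out

# One Z-column of length 5; values above 0.75 sit at indices 1 and 3
a = np.array([0.1, 0.9, 0.2, 0.8, 0.3]).reshape(1, 1, 5, 1)
print(last_above_loop(a, 0.75)[0, 0, 0])  # 3
```

Note that 0 is also a legitimate index, so with this convention "no value above n" and "only index 0 above n" produce the same result.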
Here's a faster approach -
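def faster(a, n, invalid_specifier):
    # Boolean mask of elements above the threshold
    mask = a > n
    # argmax on the Z-reversed mask finds the first True from the top,
    # which maps back to the last True index in the original order
    idx = a.shape[2] - mask[:, :, ::-1].argmax(2) - 1
    # An all-False column also yields Z-1 (argmax returns 0), so flag those
    # positions unless the last element really is above the threshold
    idx[~mask[:, :, -1] & (idx == a.shape[2] - 1)] = invalid_specifier
    return idx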
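The core trick in faster is that argmax on a boolean array returns the position of the first True, so reversing the Z axis turns "first" into "last". The 1-D sketch below walks through the same steps on a single column; last_index_above and its invalid parameter are hypothetical names standing in for faster and invalid_specifier.

```python
import numpy as np

def last_index_above(col, n, invalid=0):
    """Hypothetical 1-D version of the reversed-argmax trick used in faster."""
    mask = col > n
    # First True in the reversed mask == last True in the original order
    last = mask.size - mask[::-1].argmax() - 1
    # An all-False mask gives argmax 0, which maps to the last slot; that is
    # only a genuine hit if the last element itself passes the threshold
    if last == mask.size - 1 and not mask[-1]:
        return invalid
    return int(last)

# mask = [T, F, T, F, F], reversed = [F, F, T, F, T], argmax = 2, 5 - 2 - 1 = 2
print(last_index_above(np.array([0.9, 0.1, 0.8, 0.2, 0.3]), 0.75))  # 2
# No element above the threshold: the sentinel is returned
print(last_index_above(np.array([0.1, 0.2, 0.3]), 0.75, invalid=-1))  # -1
```

Passing a negative sentinel such as -1 (invalid_specifier=-1 in faster) keeps "nothing above n" distinguishable from a genuine hit at index 0.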
Runtime test -
# Using @DSM's benchmarking setup
In [553]: a = np.random.random((100,100,30,50))
...: n = 0.75
...:
In [554]: out1 = faster(a,n,invalid_specifier=0)
...: out2 = fast(a, axis=2, threshold=n) # @DSM's soln
...:
In [555]: np.allclose(out1,out2)
Out[555]: True
In [556]: %timeit fast(a, axis=2, threshold=n) # @DSM's soln
10 loops, best of 3: 64.6 ms per loop
In [557]: %timeit faster(a,n,invalid_specifier=0)
10 loops, best of 3: 43.7 ms per loop
Unfortunately, following the example of Python builtins, NumPy doesn't make it easy to get the last index, although the first is trivial (argmax on a boolean array returns the position of the first True). Still, something like
import numpy as np

def slow(arr, axis, threshold):
    # The running count reaches its maximum at the last True;
    # argmax returns the first position of that maximum
    return (arr > threshold).cumsum(axis=axis).argmax(axis=axis)

def fast(arr, axis, threshold):
    compare = (arr > threshold)
    reordered = compare.swapaxes(axis, -1)
    flipped = reordered[..., ::-1]
    first_above = flipped.argmax(axis=-1)
    last_above = flipped.shape[-1] - first_above - 1
    are_any_above = compare.any(axis=axis)
    # patch positions where no element exceeded the threshold
    patched = np.where(are_any_above, last_above, 0)
    return patched
gives me
In [14]: arr = np.random.random((100,100,30,50))
In [15]: %timeit a = slow(arr, axis=2, threshold=0.75)
1 loop, best of 3: 248 ms per loop
In [16]: %timeit b = fast(arr, axis=2, threshold=0.75)
10 loops, best of 3: 50.9 ms per loop
In [17]: (slow(arr, axis=2, threshold=0.75) == fast(arr, axis=2, threshold=0.75)).all()
Out[17]: True
(There's probably a slicker way to do the flipping, but it's the end of the day here and my brain is shutting down. :-)
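One flip-free alternative, offered as a sketch rather than as either answerer's method, is to weight the mask by each position's index along the axis and take the maximum. The function name last_above_via_max is hypothetical. Because False contributes 0, the result is the largest True index, and 0 when nothing exceeds the threshold, which matches the question's convention.

```python
import numpy as np

def last_above_via_max(arr, axis, threshold):
    """Hypothetical flip-free variant: largest index along `axis` where
    arr > threshold, or 0 if there is none."""
    mask = arr > threshold
    # Shape the index vector so it broadcasts along `axis` only
    shape = [1] * arr.ndim
    shape[axis] = arr.shape[axis]
    z = np.arange(arr.shape[axis]).reshape(shape)
    # False * index == 0, so the max picks the largest index where mask is True
    return (mask * z).max(axis=axis)

arr = np.array([[0.9, 0.1],
                [0.8, 0.2],
                [0.1, 0.95]])
print(last_above_via_max(arr, axis=0, threshold=0.75))  # [1 2]
```

The trade-off is that mask * z materializes an integer array the size of the input, so it may use more memory than the flip-based versions; whether it is faster is worth checking with the same %timeit setup used above.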