Reputation: 19456
I have the following data structure:
[[[ 512 520 1 130523]]
[[ 520 614 573 7448]]
[[ 614 616 615 210]]
[[ 616 622 619 269]]
[[ 622 624 623 162]]
[[ 625 770 706 8822]]
[[ 770 776 773 241]]]
I'm trying to return an object of the same shape, but containing only the rows with the 3 largest values in the 4th column (so in this case, that would be rows 1, 2 and 6).
What's the most elegant way to do this?
Upvotes: 1
Views: 6972
Reputation: 879083
You could sort the array, but as of NumPy 1.8, there is a faster way to find the N largest values (particularly when data
is large):
Using numpy.argpartition:
import numpy as np
data = np.array([[[ 512, 520, 1, 130523]],
[[ 520, 614, 573, 7448]],
[[ 614, 616, 615, 210]],
[[ 616, 622, 619, 269]],
[[ 622, 624, 623, 162]],
[[ 625, 770, 706, 8822]],
[[ 770, 776, 773, 241]]])
idx = np.argpartition(-data[...,-1].flatten(), 3)
print(data[idx[:3]])
yields
[[[ 520 614 573 7448]]
[[ 512 520 1 130523]]
[[ 625 770 706 8822]]]
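np.argpartition performs a partial sort. It returns the indices that would put the array in partially sorted order, such that the element at position kth is in its final sorted position: every element before it is smaller than or equal to it, and every element after it is larger than or equal to it. The elements within each of those two groups are not sorted relative to each other (thus saving some time). Negating the values turns the 3 largest into the 3 smallest, so their indices end up in idx[:3].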
Notice that the 3 highest rows are not returned in the same order as they appeared in data.
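If the original row order matters (the question expects rows 1, 2 and 6 in that order), one option is to sort the selected indices before indexing. This is a minimal sketch, assuming the same data array as above; the names last_col, top_idx, in_original_order and by_value_desc are only illustrative:

```python
import numpy as np

data = np.array([[[512, 520, 1, 130523]],
                 [[520, 614, 573, 7448]],
                 [[614, 616, 615, 210]],
                 [[616, 622, 619, 269]],
                 [[622, 624, 623, 162]],
                 [[625, 770, 706, 8822]],
                 [[770, 776, 773, 241]]])

last_col = data[..., -1].flatten()

# Indices of the 3 largest values, in no guaranteed order
top_idx = np.argpartition(-last_col, 3)[:3]

# Sorting the indices themselves restores the original row order
in_original_order = data[np.sort(top_idx)]

# Alternatively, order just the 3 selected rows by value, largest first
by_value_desc = data[top_idx[np.argsort(-last_col[top_idx])]]

print(in_original_order)
print(by_value_desc)
```

Both results keep the (3, 1, 4) shape, and only 3 elements are ever fully sorted, so the cost of the extra step does not grow with the size of data.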
For comparison, here is how you could find the 3 highest rows by using np.argsort
(which performs a full sort):
idx = np.argsort(data[..., -1].flatten())
print(data[idx[-3:]])
yields
[[[ 520 614 573 7448]]
[[ 625 770 706 8822]]
[[ 512 520 1 130523]]]
Note: np.argsort
is faster for small arrays:
In [63]: %timeit idx = np.argsort(data[..., -1].flatten())
100000 loops, best of 3: 2.6 µs per loop
In [64]: %timeit idx = np.argpartition(-data[...,-1].flatten(), 3)
100000 loops, best of 3: 5.61 µs per loop
But np.argpartition
is faster for large arrays:
In [92]: data2 = np.tile(data, (10**3,1,1))
In [93]: data2.shape
Out[93]: (7000, 1, 4)
In [94]: %timeit idx = np.argsort(data2[..., -1].flatten())
10000 loops, best of 3: 164 µs per loop
In [95]: %timeit idx = np.argpartition(-data2[...,-1].flatten(), 3)
10000 loops, best of 3: 49.5 µs per loop
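These timings depend on the machine and the NumPy version, so they are worth re-measuring. Here is a rough sketch for repeating the comparison outside IPython with the standard timeit module; the helper names top3_argsort and top3_argpartition are made up for this example:

```python
import timeit

import numpy as np

data = np.array([[[512, 520, 1, 130523]],
                 [[520, 614, 573, 7448]],
                 [[614, 616, 615, 210]],
                 [[616, 622, 619, 269]],
                 [[622, 624, 623, 162]],
                 [[625, 770, 706, 8822]],
                 [[770, 776, 773, 241]]])
data2 = np.tile(data, (10**3, 1, 1))  # shape (7000, 1, 4)


def top3_argsort(a):
    # Full sort of the last column, then take the last 3 indices
    idx = np.argsort(a[..., -1].flatten())
    return a[idx[-3:]]


def top3_argpartition(a):
    # Partial sort of the negated last column, then take the first 3 indices
    idx = np.argpartition(-a[..., -1].flatten(), 3)
    return a[idx[:3]]


runs = 1000
for func in (top3_argsort, top3_argpartition):
    seconds = timeit.timeit(lambda: func(data2), number=runs)
    print(f"{func.__name__}: {seconds / runs * 1e6:.1f} microseconds per call")
```

Because data2 repeats the same 7 rows, it contains many ties; both helpers return rows with the maximum value, but not necessarily the same row positions.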
Upvotes: 3
Reputation: 53525
I simplified the structure of your list-of-lists in order to focus on the main issue. You can use sorted() with a key function that returns the 4th column of each row:
my_list = [[512, 520, 1, 130523],
[520, 614 , 573, 7448],
[614, 616, 615, 210],
[616, 622, 619, 269],
[622, 624, 623, 162],
[625, 770, 706, 8822],
[770, 776, 773, 241]]
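def sort_by(a):
    # Sort key: the 4th column of each row
    return a[3]
# sorted() returns a new list and leaves my_list unchanged; reverse=True puts the largest values first
top3 = sorted(my_list, key=sort_by, reverse=True)[:3]
print(top3) # prints [[512, 520, 1, 130523], [625, 770, 706, 8822], [520, 614, 573, 7448]]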
Upvotes: 0
Reputation: 27802
You can use sorted()
and specify that you want to sort by the 4th column:
l = [[[512, 520 , 1, 130523]],
[[ 520 , 614 , 573, 7448]],
[[ 614 , 616 , 615, 210]],
[[ 616 , 622 , 619, 269]],
[[ 622 , 624 , 623, 162]],
[[ 625 , 770 , 706, 8822]],
[[ 770 , 776 , 773, 241]]]
top3 = sorted(l, key=lambda x: x[0][3], reverse=True)[:3]
print(top3)
will give you:
[[[512, 520, 1, 130523]], [[625, 770, 706, 8822]], [[520, 614, 573, 7448]]]
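The result above is ordered by the 4th column, largest first. If the rows should come back in their original order instead (rows 1, 2 and 6, as in the question), one possible variation is to rank the row positions rather than the rows. This is only a sketch; top_positions and top3_in_order are illustrative names:

```python
l = [[[512, 520, 1, 130523]],
     [[520, 614, 573, 7448]],
     [[614, 616, 615, 210]],
     [[616, 622, 619, 269]],
     [[622, 624, 623, 162]],
     [[625, 770, 706, 8822]],
     [[770, 776, 773, 241]]]

# Rank row positions by their 4th-column value, largest first, and keep 3
top_positions = sorted(range(len(l)), key=lambda i: l[i][0][3], reverse=True)[:3]

# Sort the kept positions so the rows come back in their original order
top3_in_order = [l[i] for i in sorted(top_positions)]

print(top3_in_order)
# prints [[[512, 520, 1, 130523]], [[520, 614, 573, 7448]], [[625, 770, 706, 8822]]]
```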
Upvotes: 5