cjm2671

Reputation: 19456

Sorting an array of arrays in Python

I have the following data structure:

 [[[   512    520     1 130523]]

 [[   520    614    573   7448]]

 [[   614    616    615    210]]

 [[   616    622    619    269]]

 [[   622    624    623    162]]

 [[   625    770    706   8822]]

 [[   770    776    773    241]]]

I'm trying to return an object of the same shape, but containing only the rows with the 3 largest values in the 4th column (in this case, rows 1, 2 and 6).

What's the most elegant way to do this?

Upvotes: 1

Views: 6972

Answers (3)

unutbu

Reputation: 879083

You could sort the array, but as of NumPy 1.8 there is a faster way to find the N largest values (particularly when the data is large):

Using numpy.argpartition:

import numpy as np
data = np.array([[[ 512,    520,     1, 130523]],
                 [[ 520,    614,    573,   7448]],
                 [[ 614,    616,    615,    210]],
                 [[ 616,    622,    619,    269]],
                 [[ 622,    624,    623,    162]],
                 [[ 625,    770,    706,   8822]],
                 [[ 770,    776,    773,    241]]])

idx = np.argpartition(-data[...,-1].flatten(), 3)
print(data[idx[:3]])

yields

[[[   520    614    573   7448]]

 [[   512    520      1 130523]]

 [[   625    770    706   8822]]]

np.argpartition performs a partial sort. np.argpartition(a, k) returns indices that put the element at position k exactly where a full sort would put it, with every element before it no larger and every element after it no smaller. Neither side is sorted internally, which is what saves time compared with a full sort. Negating the column makes the largest values come first, so idx[:3] holds the indices of the 3 largest values.
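To make that guarantee concrete, here is a small sketch on a 1-D array holding the 4th-column values from the question. It checks the partition property itself rather than any particular ordering, since NumPy does not promise a specific arrangement within either side.

```python
import numpy as np

# 4th-column values from the question's data
values = np.array([130523, 7448, 210, 269, 162, 8822, 241])
k = 3

# Partition the negated values so the largest originals land first
idx = np.argpartition(-values, k)
part = -values[idx]

# The element at position k is where a full sort would put it...
assert part[k] == np.sort(-values)[k]
# ...nothing before it is larger, nothing after it is smaller
assert (part[:k] <= part[k]).all()
assert (part[k + 1:] >= part[k]).all()

print(sorted(values[idx[:k]].tolist()))  # [7448, 8822, 130523]
```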

Notice that the 3 highest rows are not necessarily returned in the order they appear in data, nor in sorted order.
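If you need the selected rows back in their original order (or largest-first), you can sort just the selected indices afterwards, which stays cheap because only n of them are involved. The helper name top_n_rows and its parameters are hypothetical; this is a sketch assuming the (rows, 1, cols) layout from the question and 1 <= n <= number of rows.

```python
import numpy as np

def top_n_rows(data, n, col=-1, largest_first=False):
    """Hypothetical helper: the n rows of `data` with the largest values in `col`.

    Assumes data has shape (rows, 1, cols) as in the question.
    """
    values = data[..., col].ravel()           # one value per row
    # kth=n-1 is enough: positions 0..n-1 then hold the n largest values
    idx = np.argpartition(-values, n - 1)[:n]
    if largest_first:
        idx = idx[np.argsort(-values[idx])]   # order only the n winners
    else:
        idx = np.sort(idx)                    # restore original row order
    return data[idx]                          # same (n, 1, cols) shape

data = np.array([[[512, 520, 1, 130523]],
                 [[520, 614, 573, 7448]],
                 [[614, 616, 615, 210]],
                 [[616, 622, 619, 269]],
                 [[622, 624, 623, 162]],
                 [[625, 770, 706, 8822]],
                 [[770, 776, 773, 241]]])

print(top_n_rows(data, 3))                     # rows 1, 2 and 6, in that order
print(top_n_rows(data, 3, largest_first=True)) # 130523, 8822, 7448 rows
```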


For comparison, here is how you could find the 3 highest rows by using np.argsort (which performs a full sort):

idx = np.argsort(data[..., -1].flatten())
print(data[idx[-3:]])

yields

[[[   520    614    573   7448]]

 [[   625    770    706   8822]]

 [[   512    520      1 130523]]]
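np.argsort sorts in ascending order, so the three largest rows come out smallest-first. A minimal sketch for getting them largest-first, reversing the last three indices (it redefines the same data array so it runs on its own):

```python
import numpy as np

data = np.array([[[512, 520, 1, 130523]],
                 [[520, 614, 573, 7448]],
                 [[614, 616, 615, 210]],
                 [[616, 622, 619, 269]],
                 [[622, 624, 623, 162]],
                 [[625, 770, 706, 8822]],
                 [[770, 776, 773, 241]]])

idx = np.argsort(data[..., -1].ravel())  # ascending by the 4th column
top3 = data[idx[-3:][::-1]]              # last three, reversed: largest first
print(top3[:, 0, -1].tolist())           # [130523, 8822, 7448]
```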

Note: np.argsort is faster for small arrays:

In [63]: %timeit idx = np.argsort(data[..., -1].flatten())
100000 loops, best of 3: 2.6 µs per loop

In [64]: %timeit idx = np.argpartition(-data[...,-1].flatten(), 3)
100000 loops, best of 3: 5.61 µs per loop

But np.argpartition is faster for large arrays:

In [92]: data2 = np.tile(data, (10**3,1,1))
In [93]: data2.shape
Out[93]: (7000, 1, 4)

In [94]: %timeit idx = np.argsort(data2[..., -1].flatten())
10000 loops, best of 3: 164 µs per loop

In [95]: %timeit idx = np.argpartition(-data2[...,-1].flatten(), 3)
10000 loops, best of 3: 49.5 µs per loop
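Outside IPython, a comparable measurement can be sketched with the standard timeit module. The function names below are illustrative, and the absolute numbers will vary with hardware and NumPy version; the script also checks that both approaches pick rows with the same values.

```python
import timeit
import numpy as np

data = np.array([[[512, 520, 1, 130523]],
                 [[520, 614, 573, 7448]],
                 [[614, 616, 615, 210]],
                 [[616, 622, 619, 269]],
                 [[622, 624, 623, 162]],
                 [[625, 770, 706, 8822]],
                 [[770, 776, 773, 241]]])
data2 = np.tile(data, (10**3, 1, 1))   # shape (7000, 1, 4), as above
values = data2[..., -1].ravel()

def with_argsort():
    return np.argsort(values)[-3:]             # full sort, take the top 3

def with_argpartition():
    return np.argpartition(-values, 3)[:3]     # partial sort only

# Both must select rows carrying the same three values
assert sorted(values[with_argsort()].tolist()) == sorted(values[with_argpartition()].tolist())

for fn in (with_argsort, with_argpartition):
    t = timeit.timeit(fn, number=1000)
    print(f"{fn.__name__}: {t / 1000 * 1e6:.1f} µs per call")
```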

Upvotes: 3

Nir Alfasi

Reputation: 53525

I simplified the structure of your list-of-lists to focus on the main issue. You can use sorted() with a custom key function:

my_list = [[512, 520, 1, 130523],
           [520, 614, 573, 7448],
           [614, 616, 615, 210],
           [616, 622, 619, 269],
           [622, 624, 623, 162],
           [625, 770, 706, 8822],
           [770, 776, 773, 241]]

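def sort_by(row):
    # sort key: the value in the 4th column
    return row[3]

# sorted() returns a new list, so keep its result; reverse=True puts the largest first
top3 = sorted(my_list, key=sort_by, reverse=True)[:3]
print(top3)  # prints [[512, 520, 1, 130523], [625, 770, 706, 8822], [520, 614, 573, 7448]]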
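For plain lists, the standard library also offers heapq.nlargest, which returns the n largest items in descending order without sorting the whole list, and operator.itemgetter(3) can stand in for sort_by. This is a swapped-in alternative, sketched here for comparison, not the approach used above:

```python
import heapq
from operator import itemgetter

my_list = [[512, 520, 1, 130523],
           [520, 614, 573, 7448],
           [614, 616, 615, 210],
           [616, 622, 619, 269],
           [622, 624, 623, 162],
           [625, 770, 706, 8822],
           [770, 776, 773, 241]]

# The 3 rows with the largest 4th column, largest first
top3 = heapq.nlargest(3, my_list, key=itemgetter(3))
print(top3)  # [[512, 520, 1, 130523], [625, 770, 706, 8822], [520, 614, 573, 7448]]
```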

Upvotes: 0

jh314

Reputation: 27802

You can use sorted() and specify that you want to sort by the 4th column:

l = [[[512, 520, 1, 130523]],
     [[520, 614, 573, 7448]],
     [[614, 616, 615, 210]],
     [[616, 622, 619, 269]],
     [[622, 624, 623, 162]],
     [[625, 770, 706, 8822]],
     [[770, 776, 773, 241]]]

# x[0] unwraps the inner [...] level; x[0][3] is the 4th column
top3 = sorted(l, key=lambda x: x[0][3], reverse=True)[:3]

print(top3)

will give you:

[[[512, 520, 1, 130523]], [[625, 770, 706, 8822]], [[520, 614, 573, 7448]]]
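The data in the question looks like a printed NumPy array of shape (7, 1, 4) rather than a list. sorted() still works on it, because iterating over the array yields (1, 4) sub-arrays, so x[0][3] picks the 4th column; the result, however, is a Python list of arrays. A minimal sketch, assuming you want an array of the original shape back, wraps the result in np.array:

```python
import numpy as np

data = np.array([[[512, 520, 1, 130523]],
                 [[520, 614, 573, 7448]],
                 [[614, 616, 615, 210]],
                 [[616, 622, 619, 269]],
                 [[622, 624, 623, 162]],
                 [[625, 770, 706, 8822]],
                 [[770, 776, 773, 241]]])

# Iterating a (7, 1, 4) array yields (1, 4) sub-arrays
top3_list = sorted(data, key=lambda x: x[0][3], reverse=True)[:3]
top3 = np.array(top3_list)  # back to an array of shape (3, 1, 4)
print(top3.shape)           # (3, 1, 4)
```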

Upvotes: 5
