Curious2learn

Select elements from a numpy array based on values in another array that is not an index array

Suppose I have the following two arrays:

from numpy import array, int32

a = array([(1, 'L', 74.423088306605), (5, 'H', 128.05441039929008),
       (2, 'L', 68.0581377353869), (0, 'H', 88.15726964130869), 
       (4, 'L', 97.4501582588212), (3, 'H', 92.98550136344437),
       (7, 'L', 87.75945631669309), (6, 'L', 90.43196739694255),
       (8, 'H', 111.13662092749307), (15, 'H', 91.44444608631304),
       (10, 'L', 85.43615908319185), (11, 'L', 78.11685661303494),
       (13, 'H', 108.2841293816308), (17, 'L', 74.43917911042259),
       (14, 'H', 64.41057325770373), (9, 'L', 27.407214746467943),
       (16, 'H', 81.50506434964355), (12, 'H', 97.79700070323196),
       (19, 'L', 51.139258140713025), (18, 'H', 118.34835768605957)], 
      dtype=[('id', '<i4'), ('name', 'S1'), ('value', '<f8')])

b = array([ 0,  3,  5,  8, 12, 13, 14, 15, 16, 18], dtype=int32)

I want to select the elements of a whose id is listed in b. Note that b is not an index array: it contains the ids of the observations, not their positions in a. How can I do this in numpy?

Thanks for the help.
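For reference, here is a vectorized sketch using np.isin (NumPy 1.13 and later). None of the answers below use this function, so treat it as a swapped-in technique. It runs on a shortened six-row copy of a with rounded values. It keeps the rows in the order they appear in a, not in the order of b.

```python
import numpy as np

# Shortened six-row version of the question's array (values rounded)
a = np.array([(1, 'L', 74.42), (5, 'H', 128.05), (2, 'L', 68.06),
              (0, 'H', 88.16), (4, 'L', 97.45), (3, 'H', 92.99)],
             dtype=[('id', '<i4'), ('name', 'S1'), ('value', '<f8')])
b = np.array([0, 3, 5], dtype=np.int32)

# Boolean mask: True where a row's id occurs anywhere in b
mask = np.isin(a['id'], b)
selected = a[mask]  # rows keep the order in which they appear in a
print(selected['id'])  # [5 0 3]
```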

Upvotes: 5

Answers (3)

Jaime

The following works several times faster than Francesco's approach for your sample array:

In [7]: a[np.argmax(a['id'][None, :] == b[:, None], axis=1)]
Out[7]: 
array([(0, 'H', 88.15726964130869), (3, 'H', 92.98550136344437),
       (5, 'H', 128.05441039929008), (8, 'H', 111.13662092749307),
       (12, 'H', 97.79700070323196), (13, 'H', 108.2841293816308),
       (14, 'H', 64.41057325770373), (15, 'H', 91.44444608631304),
       (16, 'H', 81.50506434964355), (18, 'H', 118.34835768605957)], 
      dtype=[('id', '<i4'), ('name', '|S1'), ('value', '<f8')])

In [8]: %timeit a[np.argmax(a['id'][None, :] == b[:, None], axis=1)]
100000 loops, best of 3: 11.6 us per loop

In [9]: %timeit indices = [i for i,id in enumerate(a['id']) if id in b]; a[indices]
10000 loops, best of 3: 66.9 us per loop

To understand how it works, take a look at this:

In [10]: a['id'][None, :] == b[:, None]
Out[10]: 
array([[False, False, False,  True, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False],
    ... # several rows removed 
    [False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False,  True]], dtype=bool)

It is a boolean array with as many rows as there are elements in b and as many columns as there are elements in a. np.argmax then finds the position of the first True in each row, which is the index of the first appearance of the corresponding element of b in a['id'].
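Here is a small sketch of the same mechanics. It uses a plain integer array that stands in for a['id'], holding the first six ids from the sample. It also shows a caveat. If an id in b does not occur in a['id'], its row is all False, and np.argmax returns 0 for that row. Checking matches.any(axis=1) first is a cheap guard against this.

```python
import numpy as np

ids = np.array([1, 5, 2, 0, 4, 3])  # stands in for a['id']
b = np.array([0, 3, 5])

# Broadcasting (1, 6) against (3, 1) gives one row per element of b
matches = ids[None, :] == b[:, None]
print(matches.shape)  # (3, 6)

# Every row must contain a True, otherwise argmax would silently return 0
assert matches.any(axis=1).all()
idx = np.argmax(matches, axis=1)  # position of the first True in each row
print(idx)  # [3 5 1]

# Caveat: an absent id gives an all-False row, and argmax returns 0 for it
missing = np.argmax(ids[None, :] == np.array([99])[:, None], axis=1)
print(missing)  # [0]
```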

As shown above, for small arrays this beats the pure-Python approach performance-wise. But if either a or b gets too big, the size of the intermediate array of bools can cripple performance. Also, np.argmax has to search the full row: it never breaks out of the loop early, which is not a good thing if a is very long. I did some timings in an answer to this question that uses a similar approach, and there it was still the way to go for moderately large arrays.
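One way to avoid the len(b)-by-len(a) boolean intermediate is to sort the ids once and binary-search them with np.searchsorted. This is a swapped-in technique, not the one used in this answer. positions_of_ids is a hypothetical helper name. The sketch assumes you want an error rather than a silent 0 when an id is missing. Memory stays proportional to len(a) + len(b). Like the argmax trick, it returns the first occurrence of each id, and the result follows the order of b.

```python
import numpy as np


def positions_of_ids(ids, wanted):
    """Hypothetical helper: index in `ids` of the first occurrence of each value in `wanted`.

    Raises KeyError if any value in `wanted` does not occur in `ids`.
    """
    order = np.argsort(ids, kind='stable')  # permutation that sorts ids, ties keep original order
    # Binary search for each wanted id in the sorted view of ids
    pos = np.searchsorted(ids, wanted, sorter=order)
    pos = np.clip(pos, 0, len(ids) - 1)     # ids larger than every key would be out of range
    idx = order[pos]                        # map sorted positions back to original rows
    if not np.array_equal(ids[idx], wanted):
        raise KeyError("some ids in `wanted` do not occur in `ids`")
    return idx


# Shortened six-row version of the question's array (values rounded)
a = np.array([(1, 'L', 74.42), (5, 'H', 128.05), (2, 'L', 68.06),
              (0, 'H', 88.16), (4, 'L', 97.45), (3, 'H', 92.99)],
             dtype=[('id', '<i4'), ('name', 'S1'), ('value', '<f8')])
b = np.array([0, 3, 5], dtype=np.int32)

idx = positions_of_ids(a['id'], b)
print(a[idx]['id'])  # [0 3 5], same order as b
```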

Francesco's approach is definitely less hacky, easier to understand, and for an array the size of your sample the performance differences are irrelevant, I must admit. But it doesn't make you feel as much like this...

Upvotes: 5

rosjo

import numpy

sorted_a = numpy.sort(a, order='id')  # sort by id; avoids shadowing the built-in sorted()
sorted_a[b]
 array([(0, 'H', 88.15726964130869), (3, 'H', 92.98550136344437),
   (5, 'H', 128.05441039929008), (8, 'H', 111.13662092749307),
   (12, 'H', 97.79700070323196), (13, 'H', 108.2841293816308),
   (14, 'H', 64.41057325770373), (15, 'H', 91.44444608631304),
   (16, 'H', 81.50506434964355), (18, 'H', 118.34835768605957)], 
  dtype=[('id', '<i4'), ('name', '|S1'), ('value', '<f8')])

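This works only as long as there are as many ids as rows in the array, i.e. the ids are exactly 0, 1, ..., N-1, so that after sorting each row's position equals its id.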
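Here is a sketch of this approach with the condition checked explicitly. It assumes the ids should be a permutation of 0..N-1, which the shortened six-row sample below satisfies. Passing order='id' makes the sort key explicit rather than relying on field order. Naming the result sorted_a avoids shadowing Python's built-in sorted.

```python
import numpy as np

# Shortened six-row version of the question's array (values rounded)
a = np.array([(1, 'L', 74.42), (5, 'H', 128.05), (2, 'L', 68.06),
              (0, 'H', 88.16), (4, 'L', 97.45), (3, 'H', 92.99)],
             dtype=[('id', '<i4'), ('name', 'S1'), ('value', '<f8')])
b = np.array([0, 3, 5])

sorted_a = np.sort(a, order='id')  # sort rows by the 'id' field
# The trick relies on row k holding id k after sorting, i.e. ids are 0..N-1
if not np.array_equal(sorted_a['id'], np.arange(len(a))):
    raise ValueError("ids are not exactly 0..N-1; sorted_a[b] would pick wrong rows")
subset = sorted_a[b]
print(subset['id'])  # [0 3 5]
```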

Upvotes: 0

Francesco Montesano

You should get what you want with this:

indices = [i for i, id in enumerate(a['id']) if id in b]
suba = a[indices]
print(suba)
>>>array([(5, 'H', 128.05441039929008), (0, 'H', 88.15726964130869),
   (3, 'H', 92.98550136344437), (8, 'H', 111.13662092749307),
   (15, 'H', 91.44444608631304), (13, 'H', 108.2841293816308),
   (14, 'H', 64.41057325770373), (16, 'H', 81.50506434964355),
   (12, 'H', 97.79700070323196), (18, 'H', 118.34835768605957)], 
  dtype=[('id', '<i4'), ('name', '|S1'), ('value', '<f8')])
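Here is a variant of the same list comprehension. It assumes b fits comfortably in a Python set. Converting b once turns each membership test into a hash lookup. By contrast, `id in b` on a NumPy array compares against every element. The result keeps the order of a, as in the output above.

```python
import numpy as np

# Shortened six-row version of the question's array (values rounded)
a = np.array([(1, 'L', 74.42), (5, 'H', 128.05), (2, 'L', 68.06),
              (0, 'H', 88.16), (4, 'L', 97.45), (3, 'H', 92.99)],
             dtype=[('id', '<i4'), ('name', 'S1'), ('value', '<f8')])
b = np.array([0, 3, 5])

# Build the set once; each `in` check is then O(1) on average
wanted = set(b.tolist())
indices = [i for i, row_id in enumerate(a['id'].tolist()) if row_id in wanted]
suba = a[indices]
print(suba['id'])  # [5 0 3], same order as in a
```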

Upvotes: 6
