bad_keypoints

Reputation: 1400

NumPy: removing rows in an array if one column's value does not match

I have two arrays in NumPy:

a1 = 
array([[ 262.99182129,  213.        ,    1.        ],
       [ 311.98925781,  271.99050903,    2.        ],
       [ 383.        ,  342.        ,    3.        ],
       [ 372.16494751,  348.83505249,    4.        ],
       [ 214.55493164,  137.01008606,    5.        ],
       [ 138.29714966,  199.75      ,    6.        ],
       [ 289.75      ,  220.75      ,    7.        ],
       [ 239.        ,  279.        ,    8.        ],
       [ 130.75      ,  348.25      ,    9.        ]])

a2 = 
array([[ 265.78259277,  212.99705505,    1.        ],
       [ 384.23312378,  340.99707031,    3.        ],
       [ 373.66967773,  347.96688843,    4.        ],
       [ 217.91461182,  137.2791748 ,    5.        ],
       [ 141.35340881,  199.38366699,    6.        ],
       [ 292.24401855,  220.83808899,    7.        ],
       [ 241.53366089,  278.56951904,    8.        ],
       [ 133.26490784,  347.14279175,    9.        ]])

Actually there will be thousands of rows.

But as you can see, the third column in a2 does not have the value 2.0.

I simply want to remove from a1 the rows whose third-column value does not appear in the third column of a2.

What's the NumPy way/shortcut to do this fast?

Upvotes: 0

Views: 308

Answers (1)

Alex Riley

Reputation: 176820

One option is to use np.in1d to check whether each value in the third column of a1 (index 2) also appears in the third column of a2, and then use the resulting Boolean array to index the rows of a1.

You can do this as follows:

>>> a1[np.in1d(a1[:, 2], a2[:, 2])]
array([[ 262.99182129,  213.        ,    1.        ],
       [ 383.        ,  342.        ,    3.        ],
       [ 372.16494751,  348.83505249,    4.        ],
       [ 214.55493164,  137.01008606,    5.        ],
       [ 138.29714966,  199.75      ,    6.        ],
       [ 289.75      ,  220.75      ,    7.        ],
       [ 239.        ,  279.        ,    8.        ],
       [ 130.75      ,  348.25      ,    9.        ]])

The row in a1 with 2 in the third column is not in this array, as required.
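For anyone who wants to run this end to end, here is a minimal self-contained sketch of the same membership-mask idea. It uses np.isin, which the NumPy docs recommend over np.in1d in newer releases (np.in1d is deprecated as of NumPy 2.0). The shortened arrays and the names mask and filtered are just illustrative, not part of the original question.

```python
import numpy as np

# Shortened versions of the arrays from the question (first rows only).
a1 = np.array([[262.99182129, 213.0, 1.0],
               [311.98925781, 271.99050903, 2.0],
               [383.0, 342.0, 3.0]])
a2 = np.array([[265.78259277, 212.99705505, 1.0],
               [384.23312378, 340.99707031, 3.0]])

# True where the third-column value of a1 also occurs in a2's third column.
mask = np.isin(a1[:, 2], a2[:, 2])

# Boolean indexing keeps only the rows of a1 where the mask is True.
filtered = a1[mask]
print(filtered)
```

Note that the comparison is an exact equality test on floats. That is fine here because the third column holds whole-number ids, but it would probably need a tolerance-based approach if the key column contained computed floating-point values.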

Upvotes: 2
