Roman

Reputation: 3241

Extract numpy rows by given condition

I have a NumPy array as follows:

import numpy as np


data = np.array([[0, 0, 0, 4],
                 [3, 0, 5, 0],
                 [8, 9, 5, 3]])

print(data)

I need to extract only those rows whose first three elements are not all zeros. The expected result is:

result = np.array([[3,0,5,0],
                 [8,9,5,3]])

I tried:

res = [row for row in data if row[:3].sum() != 0]
print(res)

It gives the right result, but I'm looking for a better, more NumPy-idiomatic way of doing it.
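For reference, the list comprehension above returns a Python list of 1-D row arrays rather than a single 2-D array like the expected `result`. A minimal sketch of stacking it back into that shape with `np.array` (`np.vstack` would also work):

```python
import numpy as np

data = np.array([[0, 0, 0, 4],
                 [3, 0, 5, 0],
                 [8, 9, 5, 3]])

# Loop-based filter from the question: produces a list of 1-D rows
res = [row for row in data if row[:3].sum() != 0]

# Stack the surviving rows back into one 2-D array
res_array = np.array(res)

print(type(res).__name__)  # list
print(res_array.shape)     # (2, 4)
```

Note that if no rows survive, `np.array(res)` gives a 1-D empty array of shape `(0,)` rather than `(0, 4)`, which is one more reason to prefer the boolean-mask approaches in the answers.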

Upvotes: 2

Views: 84

Answers (3)

yonatansc97

Reputation: 689

I'll try to explain, through my answer, how I approach these kinds of problems.

First step: define a function that returns a boolean indicating whether a given row is a good row. For that, I use np.any, which checks whether any of the entries is truthy (for integers, any non-zero value counts as True).

import numpy as np

v1 = np.array([1, 1, 1, 0])
v2 = np.array([0, 0, 0, 1])

# A row is "good" if any of its first three entries is non-zero
good_row = lambda v: np.any(v[:3])

good_row(v1)  # True
good_row(v2)  # False

Second step: apply this to all rows at once to obtain a boolean mask. Instead of looping, use the axis keyword of np.any: axis=1 reduces across each row (one result per row), while axis=0 reduces down each column (one result per column).

np.any(data[:, :3], axis=1)
# array([False,  True,  True])
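To make the axis semantics concrete, here is a small sketch on the same data comparing both axes. With axis=1 you get one boolean per row (the mask we want); with axis=0 you get one boolean per column, which is not what this problem needs.

```python
import numpy as np

data = np.array([[0, 0, 0, 4],
                 [3, 0, 5, 0],
                 [8, 9, 5, 3]])

first3 = data[:, :3]  # shape (3, 3): first three columns of every row

per_row = np.any(first3, axis=1)  # reduce across columns: one bool per row
per_col = np.any(first3, axis=0)  # reduce across rows: one bool per column

print(per_row.tolist())  # [False, True, True]
print(per_col.tolist())  # [True, True, True]
```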

Final step: combine the mask with boolean indexing to select only the matching rows.

rows_inds = np.any(data[:, :3], axis=1)
data[rows_inds]
# array([[3, 0, 5, 0],
#        [8, 9, 5, 3]])
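The three steps can be packaged into a small reusable function. This is only a sketch; the name `rows_with_nonzero_prefix` and the parameter `k` are hypothetical, introduced here for illustration. A side benefit of boolean indexing is that the result stays 2-D even when no rows match.

```python
import numpy as np

def rows_with_nonzero_prefix(arr: np.ndarray, k: int = 3) -> np.ndarray:
    """Return the rows of a 2-D array whose first k entries are not all zero.

    Hypothetical helper: it just combines the mask and indexing steps above.
    """
    mask = np.any(arr[:, :k], axis=1)  # one bool per row
    return arr[mask]                   # boolean indexing keeps the True rows

data = np.array([[0, 0, 0, 4],
                 [3, 0, 5, 0],
                 [8, 9, 5, 3]])

print(rows_with_nonzero_prefix(data))
# An all-zero input still yields a 2-D (0, 4) result
print(rows_with_nonzero_prefix(np.zeros((2, 4), dtype=int)).shape)
```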

Upvotes: 0

Nils Werner

Reputation: 36739

You say

first three elements are not all zeros

so a solution is

import numpy as np

data = np.array([[0,0,0,4],
                 [3,0,5,0],
                 [8,9,5,3]])

data[~np.all(data[:, :3] == 0, axis=1), :]
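By De Morgan's law, "not all zero" is the same condition as "any non-zero", so this mask matches the np.any-based masks in the other answers. A quick sketch checking that on the question's data:

```python
import numpy as np

data = np.array([[0, 0, 0, 4],
                 [3, 0, 5, 0],
                 [8, 9, 5, 3]])

first3 = data[:, :3]

# "first three elements are not all zeros"
mask_not_all_zero = ~np.all(first3 == 0, axis=1)
# equivalent by De Morgan's law: "at least one of them is non-zero"
mask_any_nonzero = np.any(first3 != 0, axis=1)

print(np.array_equal(mask_not_all_zero, mask_any_nonzero))  # True
print(data[mask_not_all_zero])
```

The trailing `, :` in the indexing expression is optional; `data[mask]` selects the same rows.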

Upvotes: 0

Mad Physicist

Reputation: 114230

sum is unreliable if your array can contain negative numbers, because positive and negative entries can cancel out to a zero sum even when the row is not all zeros. any always works:

result = data[data[:, :3].any(1)]
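A small sketch of the failure mode, using hypothetical data with a negative entry (not from the question). Row 0's first three entries sum to zero, so the sum-based filter wrongly drops it, while any keeps it. Here `.any(axis=1)` is the same call as `.any(1)` above, just with the keyword spelled out.

```python
import numpy as np

# Hypothetical data: row 0's prefix [1, -1, 0] sums to 0 but is not all zeros
signed = np.array([[1, -1, 0, 7],
                   [0,  0, 0, 2],
                   [4,  0, 0, 1]])

by_sum = signed[signed[:, :3].sum(axis=1) != 0]  # drops row 0 incorrectly
by_any = signed[signed[:, :3].any(axis=1)]       # keeps rows 0 and 2

print(by_sum.tolist())  # [[4, 0, 0, 1]]
print(by_any.tolist())  # [[1, -1, 0, 7], [4, 0, 0, 1]]
```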

Upvotes: 5
