Chris

Reputation: 31206

NumPy where for a 2-dimensional array

I have a 2D NumPy array. I need to keep all the rows whose value in a specific column is greater than or equal to a certain number. Right now, I have:

    f_left = np.where(f_sorted[:, attribute] >= split_point)

It fails with: "IndexError: too many indices for array"

How should I do this? I can't figure it out from the NumPy documentation.
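A hedged note on the error itself: NumPy raises "IndexError: too many indices for array" when an array is indexed with more indices than it has dimensions. One likely cause (an assumption, since the question doesn't show how `f_sorted` is built) is that `f_sorted` is actually 1-D, for example a structured array where each "row" is a record. The sketch below uses hypothetical arrays `records` (fields `a` and `b`) and `table` to reproduce the error and show how checking `ndim` tells the two cases apart:

```python
import numpy as np

# Hypothetical data: a structured array stores one record per row,
# so it is 1-D even though each record has several fields
records = np.array([(1.0, 2.0), (3.0, 4.0)], dtype=[("a", float), ("b", float)])
print(records.ndim)  # 1

try:
    records[:, 0]  # two indices on a 1-D array
except IndexError as exc:
    print("IndexError:", exc)

# With a structured array, select the "column" by field name instead
print(records[records["b"] >= 3.0])

# A genuine 2-D array supports [:, column] indexing
table = np.array([[1.0, 2.0], [3.0, 4.0]])
print(table.ndim)  # 2
print(table[:, 0] >= 2.0)
```

If `f_sorted.ndim` turns out to be 1, the `[:, attribute]` indexing has to change before any filtering approach will work.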

Upvotes: 0

Views: 2392

Answers (2)

Garrett R

Reputation: 2662

You don't actually need where: indexing the array with a boolean condition selects the matching rows directly.

    import numpy as np

    yy = np.arange(12).reshape((4, 3))
    # Keep the rows whose value in column 1 is greater than 2
    yy[yy[:, 1] > 2]

Output

    array([[ 3,  4,  5],
           [ 6,  7,  8],
           [ 9, 10, 11]])
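To make the mechanism explicit: `yy[:, 1] > 2` produces a 1-D boolean array with one entry per row, and indexing with it keeps the rows where the entry is `True`. The sketch below uses the same `yy`, then applies the question's `>=` comparison with hypothetical values for `attribute` and `split_point` (chosen only for illustration):

```python
import numpy as np

yy = np.arange(12).reshape((4, 3))

# One boolean per row: is the value in column 1 greater than 2?
mask = yy[:, 1] > 2
print(mask)  # [False  True  True  True]

# Boolean indexing keeps only the rows where mask is True
print(yy[mask])

# The rows that fail the test are available via the negated mask
print(yy[~mask])

# Applied to the question's names (hypothetical values)
attribute, split_point = 1, 4
f_left = yy[yy[:, attribute] >= split_point]
print(f_left)
```

The mask has exactly as many entries as the array has rows, which is why it selects whole rows rather than individual elements.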

Upvotes: 4

AGS

Reputation: 14498

    import numpy as np

    x = np.array([[2, 3, 4], [5, 6, 7], [1, 2, 3], [8, 9, 10]])

    array([[ 2,  3,  4],
           [ 5,  6,  7],
           [ 1,  2,  3],
           [ 8,  9, 10]])

Find the rows whose second element is >= 4:

    x[np.where(x[:, 1] >= 4)]

    array([[ 5,  6,  7],
           [ 8,  9, 10]])
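Why this works: called with only a condition, `np.where(cond)` returns a tuple of index arrays, one per dimension of `cond`. Because `x[:, 1] >= 4` is 1-D, the result is a 1-tuple holding the matching row indices, and indexing `x` with it selects those rows. A short sketch (the names `rows`, `by_index` and `by_mask` are just for illustration) comparing it with the boolean-mask form from the other answer:

```python
import numpy as np

x = np.array([[2, 3, 4], [5, 6, 7], [1, 2, 3], [8, 9, 10]])

# With one argument, np.where returns a tuple of index arrays;
# here the condition is 1-D, so it is a 1-tuple holding row indices 1 and 3
rows = np.where(x[:, 1] >= 4)
print(rows)

# Indexing with that tuple selects the same rows as the boolean mask
by_index = x[rows]
by_mask = x[x[:, 1] >= 4]
print(np.array_equal(by_index, by_mask))  # True
```

Both give the same rows; the single-argument `np.where` is a shorthand for `np.nonzero`, so it is mainly worth using when you need the row indices themselves rather than just the filtered array.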

Upvotes: 3
