ScalaBoy

Reputation: 3392

How to select indices of rows in numpy array?

I have the following numpy array y_train:

y_train =

2
2
1
0
1
1 
2
0
0

I need to randomly select n (n=2) indices of rows as follows:

n=2 
n indices of rows where y=0 
n indices of rows where y=1 
n indices of rows where y=2

I use the following code:

n=2
idx = [y_train[np.random.choice(np.where(y_train==np.unique(y_train)[i])[0],n)].index.tolist() \
 for i in np.unique(y_train).astype(int)]

On my real array y_train, this raises the following error:

KeyError: '[70798 63260 64755 ...  7012 65605 45218] not in index'

Upvotes: 2

Views: 260

Answers (2)

panktijk

Reputation: 1614

If your expected output is a list of randomly selected indices for each unique value in y_train:

import numpy as np
n = 2
idx = [np.random.choice(np.where(y_train == i)[0], size=n, replace=False)
       for i in np.unique(y_train)]

OUTPUT:

[array([7, 8]), array([5, 4]), array([1, 0])]

If you want to flatten the arrays into a single array:

idx = np.array(idx).flatten()

OUTPUT:

array([7, 8, 5, 4, 1, 0])
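
In case it helps, here is a minimal sketch of the same idea wrapped in a reusable function. The name sample_indices_per_class and the seed parameter are my own and not from the question. It converts the input with np.asarray first, so the values it returns are positions rather than index labels. If y_train in the question is really a pandas Series, indexing it with positions is probably what triggers the KeyError, but that is an assumption on my part.

```python
import numpy as np

def sample_indices_per_class(y, n, seed=None):
    """Map each unique value in y to n distinct, randomly chosen positions."""
    y = np.asarray(y).ravel()          # plain 1-D array, so results are positions
    rng = np.random.default_rng(seed)  # pass a seed for reproducible draws
    result = {}
    for label in np.unique(y):
        positions = np.flatnonzero(y == label)  # positions where y equals label
        if len(positions) < n:
            # sampling without replacement needs at least n candidates
            raise ValueError(f"label {label!r} has only {len(positions)} rows, need {n}")
        result[label] = rng.choice(positions, size=n, replace=False)
    return result

y_train = np.array([2, 2, 1, 0, 1, 1, 2, 0, 0])
picked = sample_indices_per_class(y_train, n=2, seed=0)
flat = np.concatenate(list(picked.values()))  # one flat array of positions
print(picked)
print(flat)
```

Raising an error when a class has fewer than n rows is a design choice; you could instead sample with replace=True or take all available rows.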

Upvotes: 2

Sheldore

Reputation: 39072

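One alternative way to get the desired indices is to use nonzero and loop over the class labels. (Looping over range(n+1) gives the same result here only because the labels happen to be 0, 1, 2 and n=2.)

import numpy as np

n = 2
y_train = np.array([2,2,1,0,1,1,2,0,0])

# for each label, pick n distinct positions where y_train equals that label
indices = [np.random.choice((y_train==i).nonzero()[0], n, replace=False) for i in np.unique(y_train)]
print(indices)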
# [array([7, 3]), array([5, 4]), array([0, 1])]

print (np.array(indices).ravel())
# [7 3 5 4 0 1]
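
To sanity-check the result of either answer, you can index y_train with the selected positions and confirm that each group of n points at a single label. A small sketch (the names flat and labels_at_positions are mine):

```python
import numpy as np

y_train = np.array([2, 2, 1, 0, 1, 1, 2, 0, 0])
n = 2

# n distinct random positions per label, in the order given by np.unique
indices = [np.random.choice((y_train == i).nonzero()[0], n, replace=False)
           for i in np.unique(y_train)]
flat = np.concatenate(indices)

# Look up the labels at the chosen positions, one row per label
labels_at_positions = y_train[flat].reshape(-1, n)
print(labels_at_positions)
# [[0 0]
#  [1 1]
#  [2 2]]

# No position should be picked twice
print(len(np.unique(flat)) == len(flat))
# True
```

The positions themselves change from run to run, but the labels they point at do not.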

Upvotes: 1
