Octoplus

Reputation: 483

How to find the column index of the value of a given row in numpy?

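I want to do something relatively simple in numpy: for each row of predictions that contains exactly one 1, get the (1-based) column index of that 1, and get 0 for every other row.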

However, I ended up with rather complex code:

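import numpy as np

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

# number of 1s in each row
one_count = (predictions == 1).sum(1)
# np.where returns a tuple holding one index array: the rows with exactly one 1
valid_rows_idx = np.where(one_count==1)

result = np.zeros(predictions.shape[0])
# the loop runs once, over the single index array in the tuple
for idx in valid_rows_idx:
    result[idx] = np.where(predictions[idx,:]==1)[1] + 1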

Printing result gives [ 1.  0.  4.  0.], which is the desired result.
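A quick check of what np.where returns explains why this loop works at all: called with a single argument, np.where gives a tuple with one index array per dimension, so iterating over valid_rows_idx yields just one array of row indices and the loop body runs exactly once. The sketch below only prints the intermediate values; it reuses the question's example data.

```python
import numpy as np

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])
one_count = (predictions == 1).sum(1)

# np.where with one argument returns a tuple of index arrays (one per dimension)
valid_rows_idx = np.where(one_count == 1)
print(type(valid_rows_idx), len(valid_rows_idx))  # <class 'tuple'> 1
print(valid_rows_idx[0])                          # [0 2]

# so the for loop iterates once, with idx = array([0, 2])
for idx in valid_rows_idx:
    # 2-D result: row positions within the selection, and column positions of the 1s
    rows, cols = np.where(predictions[idx, :] == 1)
    print(rows, cols)                             # [0 1] [0 3]
```

Adding 1 to cols gives [1 4], which is assigned to result[[0, 2]].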

I was wondering if there is a simpler way of writing the last line using numpy.

Upvotes: 3

Views: 1039

Answers (1)

Anton Protopopov

Reputation: 31672

I'm not sure whether it's better, but you could try argmax for that. You also don't need a for loop or np.where to get the valid rows; a boolean mask is enough:

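import numpy as np

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

# boolean mask: True for rows that contain exactly one 1
idx = (predictions == 1).sum(1) == 1
result = np.zeros(predictions.shape[0])
# argmax gives the first True column in each selected row; +1 makes it 1-based
result[idx] = (predictions[idx]==1).argmax(axis=1) + 1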

In [55]: result
Out[55]: array([ 1.,  0.,  4.,  0.])
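The mask is what makes argmax safe here: on a boolean row, argmax returns the index of the first True, but it also returns 0 for a row with no True at all, and for a row with two 1s it silently picks the first. The short check below, using the same example data, shows rows 0 and 3 both giving 0 before masking.

```python
import numpy as np

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])
is_one = predictions == 1

# first True per row; an all-False row (row 3) also yields 0
print(is_one.argmax(axis=1))   # [0 1 3 0]

# the "exactly one 1" mask separates row 0 (valid) from rows 1 and 3 (invalid)
mask = is_one.sum(axis=1) == 1
print(mask)
```

Here mask is True only for rows 0 and 2, so only those argmax values are kept.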

Or you could do it all in one line with np.where and argmax:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

In [72]: np.where((predictions==1).sum(1)==1, (predictions==1).argmax(axis=1)+1, 0)
Out[72]: array([1, 0, 4, 0])
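The one-liner evaluates predictions==1 twice. If this is needed in several places, it could be wrapped in a small function that computes the comparison once; the name one_hot_column and its target parameter are hypothetical, a sketch rather than any NumPy API.

```python
import numpy as np

def one_hot_column(predictions, target=1):
    """Hypothetical helper: 1-based column of `target` in rows where it
    occurs exactly once, 0 for every other row."""
    hits = np.asarray(predictions) == target   # compare once, reuse below
    valid = hits.sum(axis=1) == 1              # rows with exactly one hit
    return np.where(valid, hits.argmax(axis=1) + 1, 0)

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])
print(one_hot_column(predictions))  # [1 0 4 0]
```

Like the one-liner, this returns an integer array rather than the float array produced by np.zeros.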

Upvotes: 2
