Reputation: 483
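I want to perform something relatively simple in numpy: for each row of a 2-D array that contains exactly one 1, get the 1-based column index of that 1, and get 0 for every other row.
However, I ended up with fairly complex code: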
import numpy as np

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])
one_count = (predictions == 1).sum(1)
valid_rows_idx = np.where(one_count==1)
result = np.zeros(predictions.shape[0])
for idx in valid_rows_idx:
    result[idx] = np.where(predictions[idx,:]==1)[1] + 1
If I print result, the program prints
[ 1. 0. 4. 0.]
which is the desired result.
I was wondering if there is a simpler way of writing the last line using numpy.
Upvotes: 3
Views: 1039
Reputation: 31672
I'm not sure whether it's better, but you could try using argmax for that. Also, you don't need the for loop or np.where to get the valid indices; a boolean mask is enough:
predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])
idx = (predictions == 1).sum(1) == 1
result = np.zeros(predictions.shape[0])
result[idx] = (predictions[idx]==1).argmax(axis=1) + 1
In [55]: result
Out[55]: array([ 1., 0., 4., 0.])
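To see why the mask is still needed: on a boolean array, argmax returns the position of the first True in each row, and it returns 0 for a row with no True at all. Rows with zero or several 1s therefore need to be filtered out separately. A small sketch of the intermediate values, using the same array (the names is_one, first_true and exactly_one are only illustrative):

```python
import numpy as np

predictions = np.array([[1, -1, -1, -1],
                        [-1, 1, 1, -1],
                        [-1, -1, -1, 1],
                        [-1, -1, -1, -1]])

is_one = predictions == 1              # boolean array, True where the value is 1
first_true = is_one.argmax(axis=1)     # index of the first True per row (0 if none)
exactly_one = is_one.sum(axis=1) == 1  # True only for rows with a single 1

# first_true  is [0, 1, 3, 0]: rows 1 and 3 give misleading values
# exactly_one is [True, False, True, False]: the mask that discards them
```

Row 1 has two 1s and row 3 has none, yet argmax still reports an index for both, which is why the result is only assigned where the mask is True.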
Or you could do all of that in one line using np.where and argmax:
predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])
In [72]: np.where((predictions==1).sum(1)==1, (predictions==1).argmax(axis=1)+1, 0)
Out[72]: array([1, 0, 4, 0])
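Note that this version returns integers, whereas the np.zeros version returns floats, because np.zeros defaults to float64. If it helps, the one-liner can be wrapped in a small function so the comparison with 1 is computed only once. This is just a sketch; the function name single_one_position is made up for illustration:

```python
import numpy as np

def single_one_position(predictions):
    """Hypothetical helper: 1-based column of the only 1 in each row, else 0."""
    is_one = np.asarray(predictions) == 1
    # Rows with exactly one 1 get argmax + 1; every other row gets 0.
    return np.where(is_one.sum(axis=1) == 1, is_one.argmax(axis=1) + 1, 0)

predictions = np.array([[1, -1, -1, -1],
                        [-1, 1, 1, -1],
                        [-1, -1, -1, 1],
                        [-1, -1, -1, -1]])

result = single_one_position(predictions)  # integer array with values 1, 0, 4, 0
```

np.where evaluates both branches for every row before selecting, so argmax is also computed for the rows that end up as 0; for arrays like this one that cost is negligible.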
Upvotes: 2