K. L.

Reputation: 674

Where clause with numpy

Here's my array:

a = [[0.,0.,0.1,0.2], [0.,0.3,0.4,0.3], [0.,0.,0.1,0.]]

I would like to use a where clause that returns the indices of the elements of 'a' (i.e. the inner lists) whose values sum to 1. Something like: where(sum(a) == 1)

Can someone guide me?

Thanks.

Upvotes: 1

Views: 1061

Answers (2)

JoshAdel

Reputation: 68702

In [1]: import numpy as np

In [2]: a = np.array([[0.,0.,0.1,0.2], [0.,0.3,0.4,0.3], [0.,0.,0.1,0.]])

In [3]: a
Out[3]:
array([[ 0. ,  0. ,  0.1,  0.2],
       [ 0. ,  0.3,  0.4,  0.3],
       [ 0. ,  0. ,  0.1,  0. ]])

In [4]: np.where(np.sum(a,axis=1) == 1)
Out[4]: (array([1]),)

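So the second row (index 1) is the only one whose sum is 1.0. np.sum(a, axis=1) sums along each row, giving one value per row, and each row corresponds to one of the inner lists in your original list of lists. Without an explicit axis, np.sum adds up every element of the array and returns a single scalar. Note that Python's builtin sum is not the same as np.sum: applied to a 2-D array, the builtin sum iterates over the rows and adds them together, so it returns the column sums instead. This is a good reason to avoid from numpy import * (which replaces the builtin sum with np.sum) and to keep calls explicit.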
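A small sketch contrasting the reductions described above on the same array; the variable names (total, col_sums, row_sums, builtin) are just for illustration.

```python
import numpy as np

a = np.array([[0., 0., 0.1, 0.2],
              [0., 0.3, 0.4, 0.3],
              [0., 0., 0.1, 0.]])

total = np.sum(a)             # no axis: one scalar, the sum of every element
col_sums = np.sum(a, axis=0)  # down each column -> shape (4,)
row_sums = np.sum(a, axis=1)  # along each row -> shape (3,), one value per inner list

# The builtin sum iterates over the first axis and adds the rows together,
# so on a 2-D array it yields the column sums, not the row sums.
builtin = sum(a)

print(total)
print(row_sums)
print(col_sums)
print(builtin)
```

Only axis=1 gives the per-row values the question asks about; the builtin sum matches axis=0.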

Update:

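As @Jaime pointed out, comparing floating-point sums for exact equality is not safe: rounding error can make a row that should sum to 1 come out as, say, 0.9999999999999999. Ideally np.allclose would take an axis argument, but it doesn't (it returns a single boolean for the whole array). You can still reproduce its per-row test with: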

np.where(np.abs(np.sum(a,1) - 1.0) <= 1E-5)

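See the np.allclose documentation for details on its tolerances; when comparing against 1.0, the 1E-5 above roughly matches its default relative tolerance (rtol=1e-05).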
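np.isclose is the element-wise counterpart of np.allclose (same rtol/atol parameters, defaults 1e-05 and 1e-08), so it gives the per-row tolerance test directly. A minimal sketch, assuming the default tolerances suit your data:

```python
import numpy as np

a = np.array([[0., 0., 0.1, 0.2],
              [0., 0.3, 0.4, 0.3],
              [0., 0., 0.1, 0.]])

row_sums = a.sum(axis=1)

# Element-wise test: |row_sums - 1.0| <= atol + rtol * |1.0|
mask = np.isclose(row_sums, 1.0)

# np.where on a 1-D mask returns a 1-tuple of index arrays;
# np.flatnonzero returns the index array itself.
idx_tuple = np.where(mask)
idx = np.flatnonzero(mask)

print(idx_tuple)
print(idx)
```

np.flatnonzero is convenient when you just want the indices as a flat array rather than the tuple that np.where returns.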

Upvotes: 8

falsetru

Reputation: 369274

Using enumerate and a list comprehension, without NumPy:

>>> a = [[0.,0.,0.1,0.2], [0.,0.3,0.4,0.3], [0.,0.,0.1,0.]]
>>> [i for i, xs in enumerate(a) if sum(xs) == 1]
[1]
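The same exact-equality caveat applies to this pure-Python version. A sketch using math.isclose (Python 3.5+); abs_tol=1e-5 is chosen here only to mirror the tolerance in the NumPy answer and is an assumption you should adapt:

```python
import math

a = [[0., 0., 0.1, 0.2], [0., 0.3, 0.4, 0.3], [0., 0., 0.1, 0.]]

# Exact comparison, as in the answer above
exact = [i for i, xs in enumerate(a) if sum(xs) == 1]

# Tolerant comparison: |sum(xs) - 1.0| within abs_tol (or the default rel_tol)
tolerant = [i for i, xs in enumerate(a) if math.isclose(sum(xs), 1.0, abs_tol=1e-5)]

print(exact)     # [1]
print(tolerant)  # [1]
```

Without abs_tol, math.isclose only applies a relative tolerance of 1e-09, which is much stricter than NumPy's defaults.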

Upvotes: 2
