Amateur1

Reputation: 67

Error when using the numpy.where function

I want to use the numpy.where function to check whether each element of an array equals a given string (for example "coffee") and return one vector where this is true and a different vector where it is not.

However, I keep getting the error "ValueError: operands could not be broadcast together with shapes (4,) (1,3) (1,3)".

Is there another way to do this that avoids for loops as much as possible (the exercise explicitly says I should not use them)?

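import numpy as np

lst_1 = np.array(["dog", "dog1", "dog2", "dog3"])
# Raises ValueError: operands could not be broadcast together with shapes (4,) (1,3) (1,3)
a = np.where(lst_1 == "dog", [[1,0,0]], [[0,0,0]])
print(a)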
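To see where the shapes in the error come from, here is a minimal sketch (the name cond is introduced only for illustration). The condition lst_1 == "dog" has shape (4,), and NumPy aligns shapes from the right, so its 4 is compared against the 3 of the (1,3) operands and broadcasting fails. Giving the condition a trailing axis with np.newaxis makes it (4,1), which broadcasts with the 1-D vectors to (4,3).

```python
import numpy as np

lst_1 = np.array(["dog", "dog1", "dog2", "dog3"])

cond = lst_1 == "dog"
print(cond.shape)  # (4,): one boolean per element

# np.where broadcasts cond, x and y together. Shapes are aligned from the
# right, so (4,) is compared with the trailing 3 of (1, 3) and fails.
# Adding a trailing axis makes cond (4, 1), which broadcasts with (3,)
# to (4, 3): one 3-vector per input element.
a = np.where(cond[:, np.newaxis], [1, 0, 0], [0, 0, 0])
print(a)
# [[1 0 0]
#  [0 0 0]
#  [0 0 0]
#  [0 0 0]]
```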

Upvotes: 2

Views: 152

Answers (2)

Robin Gertenbach

Reputation: 10806

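This can be done by indexing a lookup table with the boolean mask: row 0 of out is the vector for non-matches, row 1 the vector for matches, and casting the mask to integers picks one row per element:

import numpy as np

out = np.array([[0,0,0], [1,0,0]])
idx = lst_1 == "dog"          # boolean mask, shape (4,)
out[idx.astype(np.int32)]     # False -> row 0, True -> row 1; result shape (4, 3)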

Alternatively, np.take converts the boolean mask to integer indices itself, so no explicit cast is needed:

np.take([[0,0,0],[1,0,0]], lst_1 == "dog", axis=0)
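Putting both variants together as a self-contained script, as a sketch to check that they agree (the names table, mask, by_index and by_take are chosen here for illustration). Each element of the mask selects one row of the two-row table, so a length-4 mask yields a (4, 3) result.

```python
import numpy as np

lst_1 = np.array(["dog", "dog1", "dog2", "dog3"])

# Lookup table: row 0 is used where the test is False, row 1 where it is True.
table = np.array([[0, 0, 0], [1, 0, 0]])
mask = lst_1 == "dog"

# Variant 1: cast the boolean mask to integer row indices (False -> 0, True -> 1).
by_index = table[mask.astype(np.intp)]

# Variant 2: np.take converts the boolean mask to integer indices itself.
by_take = np.take(table, mask, axis=0)

print(by_index.tolist())  # [[1, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]
```

The lookup-table form also extends naturally if the two vectors change: only the rows of table need editing.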

Upvotes: 1

caaax

Reputation: 460

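If you only want to avoid writing an explicit for loop, you can use map with a lambda function (note that map still calls the lambda once per element in Python, so this is not vectorized like the NumPy approaches):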

import numpy as np

lst_1 = np.array(["dog", "dog1", "dog2", "dog3"])
a = list(map(lambda x: [1,0,0] if x=='dog' else [0,0,0], lst_1))

print(a)

> [[1, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]
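If a NumPy array rather than a list of lists is needed, the result of map can be passed to np.array. The second expression below is a vectorized alternative that is not part of the answer above and only works because the "else" vector here is all zeros: the boolean column of shape (4, 1) is multiplied by the (3,) vector and broadcasts to (4, 3).

```python
import numpy as np

lst_1 = np.array(["dog", "dog1", "dog2", "dog3"])

# map + lambda still calls the lambda once per element in Python;
# np.array turns the resulting list of lists into a (4, 3) integer array.
a = np.array(list(map(lambda x: [1, 0, 0] if x == "dog" else [0, 0, 0], lst_1)))

# Vectorized equivalent, valid here only because the "else" vector is all zeros:
# the (4, 1) boolean column times the (3,) vector broadcasts to (4, 3).
b = (lst_1 == "dog")[:, np.newaxis] * np.array([1, 0, 0])

print(np.array_equal(a, b))  # True
```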

Upvotes: 0
