Noob Saibot

Reputation: 4749

NumPy: np.where function given argmax

I'm trying to perform the following operation on two 2D NumPy arrays, a and b:

Step 1: Find the indices np.argmax(b, axis=1).
Step 2: Evaluate b[indices] > a[indices].
Step 3: Return the result as a 2D Boolean array.
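A literal translation of these three steps might look like the sketch below. It assumes a and b have the same shape and that only the first maximum in each row (the one np.argmax picks) should be marked. The helper name argmax_mask and the sample arrays are hypothetical, made up for illustration.

```python
import numpy as np

def argmax_mask(a, b):
    # Hypothetical helper: mark positions that are the row argmax of b AND where b > a
    rows = np.arange(b.shape[0])
    # Step 1: column index of the (first) maximum in each row of b
    idx = np.argmax(b, axis=1)
    # Step 2: compare b and a only at those (row, column) positions
    hit = b[rows, idx] > a[rows, idx]
    # Step 3: scatter the comparison back into a 2D Boolean array
    out = np.zeros(b.shape, dtype=bool)
    out[rows, idx] = hit
    return out

# Illustrative data only
a = np.array([[1, 5, 2],
              [4, 0, 9],
              [3, 3, 1]])
b = np.array([[2, 7, 1],
              [6, 1, 8],
              [0, 2, 5]])
result = argmax_mask(a, b)
print(result)
```

Every position that is not a row argmax stays False, so each row has at most one True.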

I tried this:

np.where((b>a)&np.argmax(b,axis=1).reshape((3,-1)), True, False)

but no dice. Any ideas?
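A likely reason the attempt misbehaves: np.argmax returns integer column indices, not a Boolean mask, so `&` performs a bitwise AND between the 0/1 values of b > a and those integers. The small arrays below are made up purely to show the effect.

```python
import numpy as np

# Illustrative data only
a = np.array([[1, 5, 2],
              [4, 0, 9],
              [3, 3, 1]])
b = np.array([[2, 7, 1],
              [6, 1, 8],
              [0, 2, 5]])

# Integer column indices of each row's maximum, shape (3, 1): not True/False values
idx = np.argmax(b, axis=1).reshape((3, -1))
# Bitwise AND of (b > a) as 0/1 with each row's index, i.e. index & 1 where b > a
mixed = (b > a) & idx
attempt = np.where(mixed, True, False)
print(attempt)
```

An element survives only where b > a and its row's argmax index happens to be odd, which has nothing to do with the intended test.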

Thanks in advance.

Upvotes: 1

Views: 2598

Answers (1)

YXD

Reputation: 32511

Based on your comments, my best understanding is:

output = (np.max(b,axis=1)[...,None] == b) & (b > a)
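To check the one-liner on concrete data, here is a quick run with small illustrative arrays (made up for this example). It also shows one behavioural difference from a strict argmax reading: when a row of b has a tied maximum, every tied position can come out True, whereas np.argmax would select only the first.

```python
import numpy as np

# Illustrative data only
a = np.array([[1, 5, 2],
              [4, 0, 9],
              [3, 3, 1]])
b = np.array([[2, 7, 1],
              [6, 1, 8],
              [0, 2, 5]])

# True where an element is its row's maximum in b AND b beats a there
output = (np.max(b, axis=1)[..., None] == b) & (b > a)
print(output)

# A row with a tied maximum: both tied positions are marked
b_tie = np.array([[3, 3, 1]])
a_tie = np.zeros_like(b_tie)
tie_output = (np.max(b_tie, axis=1)[..., None] == b_tie) & (b_tie > a_tie)
print(tie_output)
```

If only the first maximum should count, the argmax-based indexing from the question's steps is the closer match; otherwise the broadcasting version is shorter.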

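Here, NumPy broadcasting handles the "is this element the maximum of its row in b?" part. np.max(b, axis=1) has shape (n_rows,); indexing it with [..., None] appends an axis to give shape (n_rows, 1), which then broadcasts against b's shape (n_rows, n_cols):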

np.max(b,axis=1)[...,None] == b

Or, perhaps more clearly, with np.newaxis (which is simply an alias for None):

np.max(b,axis=1)[...,np.newaxis] == b
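The shapes involved can be checked directly. The sketch below also uses keepdims=True, an equivalent way to keep the reduced axis that the answer itself does not use; the array b is illustrative only.

```python
import numpy as np

# Illustrative data only
b = np.array([[2, 7, 1],
              [6, 1, 8],
              [0, 2, 5]])

row_max = np.max(b, axis=1)                # shape (3,): one maximum per row
col = row_max[..., np.newaxis]             # shape (3, 1): column vector of row maxima
same = np.max(b, axis=1, keepdims=True)    # shape (3, 1) directly, via keepdims
is_row_max = col == b                      # (3, 1) broadcasts against (3, 3)
print(row_max.shape, col.shape, same.shape)
print(is_row_max)
```

Without the extra axis, comparing a (3,) array against a (3, 3) array would broadcast along the wrong dimension, matching each column rather than each row against the maxima.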

Upvotes: 4
