Rishab Borah

Reputation: 11

How to use argmax on a numpy array

I'm using Keras for deep learning in Python. I have a numpy array with this shape: [5, 30, 30, 30, 65]. I want to perform an argmax along the 5th dimension (length 65), setting the highest value in each length-65 vector to 1 and all others to 0. For context, the array is output by a generator model, and I'm using the last dimension of length 65 as a one-hot encoding. How would I go about doing this?

Upvotes: 1

Views: 263

Answers (1)

Jacob K

Reputation: 784

np.argmax takes an axis parameter that uses 0-based counting, so the 5th dimension is axis=4.

Therefore, to get the argmax along the 5th dimension at position (a, b, c, d) you would use:
idx = np.argmax(arr, axis=4)[a, b, c, d]

Then, to convert that length-65 vector to a one-hot encoding, you can overwrite it with an np.zeros array and then set the value at idx (from above) to 1, as follows:

arr[a, b, c, d] = np.zeros((65,))
arr[a, b, c, d, idx] = 1

where a, b, c, and d are arbitrary indices along the 1st through 4th dimensions.
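The steps above handle one (a, b, c, d) position at a time. As a minimal sketch of doing the same thing for every position at once, the following uses np.put_along_axis (available in NumPy 1.15 and later, as far as I know). The function name one_hot_along_last_axis and the random test array are made up for illustration; the shape is the one from the question.

```python
import numpy as np

def one_hot_along_last_axis(arr):
    """Return an array shaped like arr with 1 at the argmax of the last axis, 0 elsewhere."""
    # Index of the largest value along the last axis; shape is arr.shape[:-1].
    idx = np.argmax(arr, axis=-1)
    one_hot = np.zeros_like(arr)
    # put_along_axis expects indices with the same number of dimensions as
    # the target, so add a trailing axis of length 1.
    np.put_along_axis(one_hot, idx[..., np.newaxis], 1, axis=-1)
    return one_hot

# Hypothetical stand-in for the generator output from the question.
rng = np.random.default_rng(0)
arr = rng.random((5, 30, 30, 30, 65))
one_hot = one_hot_along_last_axis(arr)
```

This returns a new array rather than modifying arr in place. If several entries tie for the maximum, np.argmax picks the first one, so each length-65 vector still ends up with exactly one 1.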

Upvotes: 1
