Antonio Paladini

Reputation: 83

Numpy select over second axis

I know this is supposed to be simple but I can't figure it out.

The problem:

import numpy as np

gt_prices = np.random.uniform(0, 100, size=(121147, 28))
pred_idxs = np.random.randint(0, 28, size=(121147,))
print(gt_prices.shape, pred_idxs.shape)
# (121147, 28) (121147,)

I want to get an array of shape (121147,), where entry i is the element of row i of gt_prices at the column given by pred_idxs[i]. In other words, I want to do this:

selected_prices = np.array([gt_prices[i, pred_idxs[i]] for i in range(gt_prices.shape[0])])

But I'd like to do this with vectorized NumPy operations instead of a Python loop. Is this possible?

Upvotes: 1

Views: 1053

Answers (2)

user1519665

Reputation: 511

NumPy (1.15 and later) provides a convenience function for this, np.take_along_axis: https://numpy.org/devdocs/reference/generated/numpy.take_along_axis.html

For your usage, I believe it would be:

gt_prices = np.random.uniform(0, 100, size=(121147, 28))
pred_idxs = np.random.randint(0, 28, size=(121147, 1))  # must have the same number of dimensions as gt_prices
your_output = np.take_along_axis(gt_prices, pred_idxs, axis=1)  # output shape (121147, 1)
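
If pred_idxs is 1-D, as in the question, one way to get a result of shape (N,) is to add a trailing axis to the indices and drop it again afterwards. This is only a minimal sketch: the seeded generator, the small size of 5 rows, and the names picked and expected are my own choices for illustration, not part of the question.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded only so the sketch is reproducible
gt_prices = rng.uniform(0, 100, size=(5, 28))
pred_idxs = rng.integers(0, 28, size=5)  # 1-D, as in the question

# take_along_axis needs indices with the same number of dimensions as the array,
# so add a trailing axis to the indices, then remove it from the result.
picked = np.take_along_axis(gt_prices, pred_idxs[:, None], axis=1)  # shape (5, 1)
selected_prices = picked.squeeze(axis=1)                            # shape (5,)

# Compare against the list comprehension from the question.
expected = np.array([gt_prices[i, pred_idxs[i]] for i in range(gt_prices.shape[0])])
assert np.array_equal(selected_prices, expected)
```

Using squeeze(axis=1) rather than a bare squeeze() keeps the row axis intact even if the array happens to have a single row.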

Upvotes: 0

Pani

Reputation: 1377

You can do the following (I used only 3 rows so the result is easier to check by hand):

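gt_prices = np.random.uniform(0, 100, size=(3, 28))
pred_idxs = np.random.randint(0, 28, size=(3,))
indices = np.expand_dims(pred_idxs, axis=1)  # shape (3, 1)
# Row indices and column indices are both (3, 1), so the result has shape (3, 1)
selected_prices = gt_prices[np.arange(gt_prices.shape[0])[:, None], indices]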
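
For what it's worth, the same integer-array indexing also seems to work without the extra axis, in which case the result comes out with shape (N,) as the question asks. A small sketch under that assumption; the seeded generator and the names rows and two_d are mine, chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded only so the sketch is reproducible
gt_prices = rng.uniform(0, 100, size=(3, 28))
pred_idxs = rng.integers(0, 28, size=3)

rows = np.arange(gt_prices.shape[0])

# Pairing a 1-D row index with a 1-D column index picks one element per row.
selected_prices = gt_prices[rows, pred_idxs]          # shape (3,)

# With a trailing axis on both index arrays, the result keeps that axis.
two_d = gt_prices[rows[:, None], pred_idxs[:, None]]  # shape (3, 1)

# Both forms select the same elements as the loop from the question.
expected = np.array([gt_prices[i, pred_idxs[i]] for i in range(gt_prices.shape[0])])
assert np.array_equal(selected_prices, expected)
assert np.array_equal(two_d[:, 0], expected)
```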

Upvotes: 1
