user118967

Indexing one element per row in PyTorch matrix

Let price be a PyTorch tensor with shape (num_days, num_products).

Let purchased_product_by_day be an integer tensor with shape (num_days,), with values in range(num_products).

Intuitively, price[d, p] is the price of product p on day d, and purchased_product_by_day[d] is the index of the product purchased on day d (one purchase per day).

To obtain a tensor containing the expense per day, I can write

price[list(range(num_days)), purchased_product_by_day]
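As a minimal runnable sketch of this setup, here it is with small made-up sizes and random prices. The concrete values of num_days, num_products and the data are illustrative assumptions, not from the question.

```python
import torch

# Illustrative sizes and data (assumptions, not from the question).
num_days, num_products = 4, 3
torch.manual_seed(0)
price = torch.rand(num_days, num_products)
# One purchased product index per day (torch.randint returns int64 by default).
purchased_product_by_day = torch.randint(0, num_products, (num_days,))

# Advanced indexing pairs row i with column purchased_product_by_day[i].
expense = price[list(range(num_days)), purchased_product_by_day]
print(expense.shape)  # torch.Size([4])
```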

but this builds the row indices with a Python-level list(range(num_days)). I would like to obtain the same tensor entirely at the C level, without Python-level iteration.

I tried

price[:, purchased_product_by_day]

but that does not work: it indexes every row with the whole purchased_product_by_day tensor and stacks the results, so the output has shape (num_days, num_days) instead of one element per day.
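A small sketch with hypothetical values makes the mismatch concrete: price[:, purchased_product_by_day] returns a (num_days, num_days) tensor whose entry [i, j] is price[i, purchased_product_by_day[j]], so the wanted per-day expenses are only its diagonal.

```python
import torch

# Hypothetical data: price[d, p] = 3 * d + p, purchases chosen by hand.
num_days, num_products = 4, 3
price = torch.arange(num_days * num_products, dtype=torch.float32).reshape(num_days, num_products)
purchased_product_by_day = torch.tensor([2, 0, 1, 2])

wrong = price[:, purchased_product_by_day]
print(wrong.shape)  # torch.Size([4, 4])

# wrong[d, d] == price[d, purchased_product_by_day[d]], i.e. the desired values
print(wrong.diagonal())
```

Taking the diagonal does give the right numbers, but it materializes num_days² entries only to keep num_days of them, so it is not a good fix for large inputs.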

Is there a way to do that without iteration at the Python level?

Answers (1)

user118967

Based on the answer to "Row-wise Element Indexing in PyTorch for C++", the Python solution is

price.gather(1, purchased_product_by_day.unsqueeze(1)).squeeze(1)
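A minimal sketch of the gather approach, assuming small random data (the sizes and values are illustrative). With dim=1, gather computes out[d, 0] = price[d, index[d, 0]], so unsqueezing the index to shape (num_days, 1) picks exactly one column per row. For comparison, advanced indexing with a tensor row index from torch.arange is also fully vectorized and gives the same result.

```python
import torch

# Illustrative sizes and data (assumptions, not from the question).
num_days, num_products = 5, 3
torch.manual_seed(0)
price = torch.rand(num_days, num_products)
# gather expects an int64 index tensor; torch.randint returns int64 by default.
purchased_product_by_day = torch.randint(0, num_products, (num_days,))

# gather along dim 1: out[d, 0] = price[d, purchased_product_by_day[d]]
expense = price.gather(1, purchased_product_by_day.unsqueeze(1)).squeeze(1)

# Equivalent vectorized advanced indexing with a tensor of row indices.
expense_indexed = price[torch.arange(num_days), purchased_product_by_day]

print(expense.shape)  # torch.Size([5])
print(torch.equal(expense, expense_indexed))  # True
```

squeeze(1) is used instead of a bare squeeze() so that the result keeps shape (num_days,) even when num_days == 1, where squeeze() would return a 0-dimensional tensor.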
