Reputation: 5140
I have a 2D tensor and an index tensor. The 2D tensor has a batch dimension and a second dimension with 3 values. The index tensor selects exactly 1 of those 3 values for each row in the batch. What is the "best" way to produce a 1D tensor containing just the elements selected by the index tensor?
t = torch.tensor([[1,2,3], [4,5,6], [7,8,9]])
# tensor([[1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])
i = torch.tensor([0,0,1], dtype=torch.int64)
# tensor([0, 0, 1])
Expected output:
tensor([1, 4, 8])
Upvotes: 0
Views: 1259
Reputation: 8803
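Index both dimensions at once. When you pass a sequence of row indices and a sequence of column indices together, they are paired element-wise (row 0 with column 0, row 1 with column 0, row 2 with column 1), so you get one element per row. For example: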
import torch
t = torch.tensor([[1,2,3], [4,5,6], [7,8,9]])
col_i = [0, 0, 1]
row_i = range(3)
print(t[row_i, col_i])
# tensor([1, 4, 8])
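If the indices are already in a tensor, as `i` is in the question, the same idea can be written without Python lists. The following is only a sketch under that assumption: it builds the row indices with `torch.arange`, and also shows `torch.gather` as an alternative. The names `rows`, `picked_index` and `picked_gather` are mine, not from the question.

```python
import torch

t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
i = torch.tensor([0, 0, 1], dtype=torch.int64)

# Option 1: advanced indexing, pairing each row index with its column index.
rows = torch.arange(t.size(0))
picked_index = t[rows, i]

# Option 2: gather along dim 1. The index must have the same number of
# dimensions as t, so add a trailing dimension and squeeze it away afterwards.
picked_gather = t.gather(1, i.unsqueeze(1)).squeeze(1)

print(picked_index)   # tensor([1, 4, 8])
print(picked_gather)  # tensor([1, 4, 8])
```

Both give the same result here. The indexing form is probably easier to read, while `gather` may be the more natural fit when the index tensor already has the same number of dimensions as the input.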
Upvotes: 2