Duane

Reputation: 5140

Selecting second dim of tensor using an index tensor

I have a 2D tensor and an index tensor. The 2D tensor has a batch dimension and a second dimension of size 3. For each row, the index tensor selects exactly one of those 3 values. What is the "best" way to produce a 1D tensor containing just the selected elements?

import torch

t = torch.tensor([[1,2,3], [4,5,6], [7,8,9]])
# tensor([[1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])

i = torch.tensor([0,0,1], dtype=torch.int64)
# tensor([0, 0, 1])

Expected output...

tensor([1, 4, 8])
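To pin down the intended semantics, here is a minimal sketch (assuming the `t` and `i` above) of a plain Python loop that produces the expected output. It is slow and only meant as a reference to check vectorized versions against.

```python
import torch

t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
i = torch.tensor([0, 0, 1], dtype=torch.int64)

# Reference semantics: for each row r, keep the single element t[r, i[r]].
# Each t[r, i[r]] is a 0-d tensor; torch.stack joins them into a 1-D tensor.
out = torch.stack([t[r, i[r]] for r in range(t.size(0))])
print(out)  # tensor([1, 4, 8])
```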

Upvotes: 0

Views: 1259

Answers (1)

Keiku

Reputation: 8803

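Use advanced (integer array) indexing: pass one index for the rows and one for the columns, and PyTorch pairs them element-wise, so here it picks t[0, 0], t[1, 0] and t[2, 1].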

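import torch

t = torch.tensor([[1,2,3], [4,5,6], [7,8,9]])
col_i = torch.tensor([0, 0, 1])    # column to pick in each row (the index tensor i)
row_i = torch.arange(t.size(0))    # one row index per batch element: 0, 1, 2
print(t[row_i, col_i])
# tensor([1, 4, 8])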
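If you would rather not build the row index yourself, `torch.gather` along dim 1 is a common alternative. Note this is a different technique from the advanced indexing above; the sketch assumes the `t` and `i` from the question.

```python
import torch

t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
i = torch.tensor([0, 0, 1], dtype=torch.int64)

# torch.gather needs an index with the same number of dims as t,
# so reshape i from (3,) to (3, 1): picked[r][0] = t[r][i[r]].
picked = torch.gather(t, 1, i.unsqueeze(1))  # shape (3, 1)
# Drop the trailing size-1 dimension to get a 1-D result.
out = picked.squeeze(1)
print(out)  # tensor([1, 4, 8])
```

`gather` works for any batch size without an explicit `arange`, at the cost of the `unsqueeze`/`squeeze` reshaping.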

Upvotes: 2
