Valentin Macé

Reputation: 1272

Index a torch tensor with an array

I have the following torch tensor:

tensor([[-0.2,  0.3],
    [-0.5,  0.1],
    [-0.4,  0.2]])

and the following NumPy array (I can convert it to something else if necessary):

[1 0 1]

I want to get the following tensor:

tensor([0.3, -0.5, 0.2])

That is, I want each element of the NumPy array to select a column from the corresponding row of my tensor, preferably without using a loop.

Thanks in advance

Upvotes: 7

Views: 10472

Answers (2)

Dishin H Goyani

Reputation: 7693

You may want to use torch.gather, which "Gathers values along an axis specified by dim."

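import numpy as np
import torch

t = torch.tensor([[-0.2,  0.3],
                  [-0.5,  0.1],
                  [-0.4,  0.2]])
idxs = np.array([1, 0, 1])

# gather needs an int64 index with the same number of dims as t, so add a column dim
idxs = torch.from_numpy(idxs).long().unsqueeze(1)
# or:  torch.from_numpy(idxs).long().view(-1, 1)
t.gather(1, idxs)
# tensor([[ 0.3000],
#         [-0.5000],
#         [ 0.2000]])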

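Here, your index is a NumPy array, so you have to convert it to a LongTensor (int64) first, because torch.gather expects an int64 index tensor. The result has shape (3, 1); call .squeeze(1) on it if you want the 1-D tensor from the question.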
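For intuition, here is a plain-Python sketch of what gather computes along dim=1 for a 2-D input: out[i][j] = t[i][index[i][j]]. The helper name gather_dim1 is hypothetical and works on nested lists rather than tensors; it only mirrors the indexing rule for this one case (2-D input, dim=1) and is not PyTorch's implementation.

```python
def gather_dim1(t, index):
    """Mimic torch.gather(t, 1, index) for 2-D nested lists (hypothetical helper).

    The output has the same shape as `index`: entry [i][j] is taken from
    row i of `t`, at the column given by index[i][j].
    """
    return [[t[i][col] for col in row] for i, row in enumerate(index)]


t = [[-0.2, 0.3],
     [-0.5, 0.1],
     [-0.4, 0.2]]
idxs = [[1], [0], [1]]  # shape (3, 1), like the unsqueezed index above

out = gather_dim1(t, idxs)
print(out)  # [[0.3], [-0.5], [0.2]]
```

Taking the single element of each inner list (the equivalent of .squeeze(1)) gives the flat [0.3, -0.5, 0.2] from the question.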

Upvotes: 4

Zabir Al Nazi Nabil

Reputation: 11198

Simply index the first dimension with range(len(index)) and the second dimension with your index list; the two sequences are paired element by element.

import torch

a = torch.tensor([[-0.2,  0.3],
                  [-0.5,  0.1],
                  [-0.4,  0.2]])

c = [1, 0, 1]

# Pair row k with column c[k]: picks a[0, 1], a[1, 0], a[2, 1]
b = a[range(len(c)), c]

print(b)  # values: 0.3, -0.5, 0.2
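This works because, when two dimensions are indexed with sequences of the same length, PyTorch (like NumPy) pairs them element-wise: b[k] = a[rows[k], cols[k]]. Below is a minimal plain-Python sketch of that rule; the helper name paired_index is hypothetical, operates on nested lists, and deliberately ignores broadcasting of unequal-length indices.

```python
def paired_index(a, rows, cols):
    """Mimic a[rows, cols] on a 2-D nested list (hypothetical helper).

    Element k of the result is a[rows[k]][cols[k]]; broadcasting is not handled.
    """
    if len(rows) != len(cols):
        raise ValueError("rows and cols must have the same length")
    return [a[r][c] for r, c in zip(rows, cols)]


a = [[-0.2, 0.3],
     [-0.5, 0.1],
     [-0.4, 0.2]]
c = [1, 0, 1]

b = paired_index(a, range(len(c)), c)
print(b)  # [0.3, -0.5, 0.2]
```

Note the difference from a[:, c], which keeps every row and selects all listed columns for each one, giving a 3×3 result instead of one value per row.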

Upvotes: 3
