roger

Reputation: 9893

How to generate a new tensor from a given tensor and index tensor in PyTorch?

I have a tensor like this:

x = torch.tensor([[3, 4, 2], [0, 1, 5]])

and I have an index tensor like this:

ind = torch.tensor([[1, 1, 0], [0, 0, 1]])

I want to generate a new tensor from x and ind, where z[i][j] = x[ind[i][j]][j]:

z = torch.tensor([[0, 1, 2], [3, 4, 5]])

I implemented it with plain Python loops like this:

# -*- coding: utf-8 -*-

import torch

x = torch.tensor([[3, 4, 2], [0, 1, 5]])
ind = torch.tensor([[1, 1, 0], [0, 0, 1]])

z = torch.zeros_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        z[i, j] = x[ind[i][j]][j]

print(z)

How can I do this with a built-in PyTorch operation instead of the Python loops?

Upvotes: 1

Views: 91

Answers (1)

Shai

Reputation: 114786

You are looking for torch.gather. With dim=0 it computes out[i][j] = x[ind[i][j]][j], which is exactly what your nested loops do:

In [1]: import torch

In [2]: x = torch.tensor([[3, 4, 2], [0, 1, 5]])

In [3]: ind = torch.tensor([[1, 1, 0], [0, 0, 1]])

In [4]: torch.gather(x, 0, ind)
Out[4]:
tensor([[0, 1, 2],
        [3, 4, 5]])
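
If you want to convince yourself that this matches the loop version, here is a minimal sketch that compares the two. It also shows what I believe is an equivalent advanced-indexing form; the names z_loop, z_gather, z_index and cols are just illustrative and not from the question.

```python
import torch

x = torch.tensor([[3, 4, 2], [0, 1, 5]])
ind = torch.tensor([[1, 1, 0], [0, 0, 1]])

# Reference result: the explicit nested loops from the question.
z_loop = torch.zeros_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        z_loop[i, j] = x[ind[i, j], j]

# gather along dim 0: z_gather[i][j] = x[ind[i][j]][j]
z_gather = torch.gather(x, 0, ind)

# Advanced indexing: ind picks the row, cols (broadcast over rows) picks the column.
cols = torch.arange(x.shape[1])
z_index = x[ind, cols]

print(z_gather)
print(torch.equal(z_gather, z_loop), torch.equal(z_index, z_loop))
```

Note that torch.gather expects ind to be an int64 tensor, which is what torch.tensor produces from Python integers by default, and ind must have the same number of dimensions as x.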

Upvotes: 1
