ihdv

Reputation: 2307

In PyTorch, is there a built-in method to extract rows with given indexes?

Suppose I have a torch tensor

import torch
a = torch.tensor([[1,2,3],
                  [4,5,6],
                  [7,8,9]])

and a list

b = [0,2]

Is there a built-in method to extract rows 0 and 2 and put them in a new tensor, like this:

tensor([[1,2,3],
        [7,8,9]])

In particular, is there a function that looks like this:

extract_rows(a,b) -> c

where c contains the desired rows? Sure, this can be done with a for loop, but a built-in method is generally faster.
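For reference, the for-loop approach might look like the sketch below. `extract_rows_loop` is a hypothetical name modelled on the `extract_rows` signature above, not a PyTorch function. It indexes one row per iteration and joins the rows with `torch.stack`, so it issues one small operation per requested row and tends to be slower than a single vectorized call when the index list is long. It also assumes `b` is non-empty, since `torch.stack` raises an error on an empty list.

```python
import torch

def extract_rows_loop(a: torch.Tensor, b: list) -> torch.Tensor:
    # Hypothetical helper: pick each requested row one at a time,
    # then stack them into a new tensor along dimension 0.
    # Assumes b is non-empty (torch.stack fails on an empty list).
    return torch.stack([a[i] for i in b])

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
b = [0, 2]

c = extract_rows_loop(a, b)
print(c)
# tensor([[1, 2, 3],
#         [7, 8, 9]])
```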

Note that this is only an illustration: in practice there could be dozens of indices in the list and hundreds of rows in the tensor.

Upvotes: 1

Views: 1067

Answers (2)

Dishin H Goyani

Reputation: 7723

Simply indexing with a[b] works:

import torch
a = torch.tensor([[1,2,3],
                  [4,5,6],
                  [7,8,9]])
b = [0,2]
a[b]
# tensor([[1, 2, 3],
#         [7, 8, 9]])
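Advanced indexing also accepts an integer tensor of indices, which is convenient when the indices are themselves computed in PyTorch. The small sketch below (the index values are illustrative) shows that the output follows the order of the indices, that indices may repeat, and that the result is a copy rather than a view of `a`:

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# Indices given as an int64 tensor instead of a Python list.
idx = torch.tensor([2, 0, 2])

# Integer-tensor indexing on the first dimension returns a new tensor
# whose rows appear in the order given by idx; repeats are allowed.
c = a[idx]
print(c)
# tensor([[7, 8, 9],
#         [1, 2, 3],
#         [7, 8, 9]])

# The result is a copy, so modifying it leaves the original unchanged.
c[0, 0] = 100
print(a[2, 0])  # tensor(7)
```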

Upvotes: 0

Sheri

Reputation: 1415

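Have a look at PyTorch's built-in index_select() method, which does exactly this. Alternatively, you can use slicing, although that only works when the indices are evenly spaced (here, every second row starting from row 0):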

import torch

tensor = torch.tensor([[1,2,3],
                       [4,5,6],
                       [7,8,9]])

# Take every second row starting from row 0, i.e. rows 0 and 2
new_tensor = tensor[0::2]
print(new_tensor)

Output:

tensor([[1, 2, 3],
        [7, 8, 9]])
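A minimal sketch of the index_select() route, assuming the indices start out as the Python list `b` from the question. Unlike slicing, it handles arbitrary, unevenly spaced indices; passing dim=1 instead of 0 would select columns rather than rows.

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
b = [0, 2]

# index_select expects the indices as a 1-D integer tensor.
idx = torch.tensor(b, dtype=torch.long)

# Select along dim=0, i.e. pick whole rows.
c = torch.index_select(a, 0, idx)
print(c)
# tensor([[1, 2, 3],
#         [7, 8, 9]])

# The method form gives the same result as plain indexing with the list.
assert torch.equal(a.index_select(0, idx), a[b])
```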

Upvotes: 1
