Reputation: 2307
Suppose I have a torch tensor
import torch
a = torch.tensor([[1,2,3],
[4,5,6],
[7,8,9]])
and a list
b = [0,2]
Is there a built-in method to extract the rows 0 and 2 and put them in a new tensor:
tensor([[1,2,3],
[7,8,9]])
In particular, is there a function that looks like this:
extract_rows(a,b) -> c
where c contains the desired rows. Sure, this can be done with a for loop, but a built-in method is generally faster.
Note that this is only an example; there could be dozens of indices in the list and hundreds of rows in the tensor.
Upvotes: 1
Views: 1067
Reputation: 7723
Simply a[b] would work:
import torch
a = torch.tensor([[1,2,3],
[4,5,6],
[7,8,9]])
b = [0,2]
a[b]
tensor([[1, 2, 3],
[7, 8, 9]])
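To connect this back to the question, here is a minimal sketch that wraps the indexing in a function. The name extract_rows comes from the question and is not a torch function; the helper is purely illustrative. The sketch also shows that indexing with a list (or an index tensor) returns a new tensor rather than a view, so modifying the result leaves a untouched.

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
b = [0, 2]


def extract_rows(t, rows):
    # Hypothetical helper named after the question; it only wraps t[rows].
    return t[rows]


c = extract_rows(a, b)

# Indexing with an integer tensor gives the same rows as indexing with a list.
c_from_tensor = a[torch.tensor(b)]
assert torch.equal(c, c_from_tensor)

# The result is a copy: writing into c does not change a.
c[0, 0] = 100
print(a[0, 0].item())
```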
Upvotes: 0
Reputation: 1415
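Have a look at torch's built-in index_select() method, which selects entries along a given dimension using an index tensor. Alternatively, when the rows you want are evenly spaced (as rows 0 and 2 are here), you can use slicing. Note that slicing does not generalize to an arbitrary list of indices.
import torch
tensor = torch.tensor([[1,2,3],
[4,5,6],
[7,8,9]])
# take every second row, starting at row 0
new_tensor = tensor[0::2]
print(new_tensor)
Output:
tensor([[1, 2, 3],
        [7, 8, 9]])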
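For the index_select() route mentioned at the start of this answer, a short sketch might look like the following. It assumes the a and b from the question; the variable names index, rows, and cols are only illustrative. The index has to be passed as an integer tensor, not a plain Python list.

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
b = [0, 2]

# index_select expects the indices as an integer tensor.
index = torch.tensor(b, dtype=torch.long)

# dim=0 selects whole rows.
rows = torch.index_select(a, 0, index)
print(rows)

# The same call with dim=1 selects columns instead.
cols = a.index_select(1, index)
print(cols)
```

Unlike the slice 0::2, this works for any list of row indices, in any order.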
Upvotes: 1