Suppose I have a 2D tensor looking something like this:
[[44, 50,  1, 32],
 ...
 [ 7, 13, 90, 83]]
and a list of row indices I want to select, something like [0, 34, 100, ..., 745]. How can I create a new tensor that contains only the rows whose indices are in that list?
You can select the rows the same way you would in NumPy, by indexing the tensor with the list of row indices:
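```python
import torch

x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 8, 7, 6],
                  [5, 4, 2, 1]])
indices = [0, 3]  # rows to keep

# Indexing with a list of ints selects those rows (returns a new tensor, not a view)
print(x[indices])
# tensor([[1, 2, 3, 4],
#         [5, 4, 2, 1]])
```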