Reputation: 1306
Assume I have a tensor sequences of shape [8, 12, 2]. Now I would like to select one entry along dimension 1 for each index of the first dimension, which results in a tensor of shape [8, 2]. The selection over dimension 1 is specified by indices stored in a long tensor indices of shape [8].
I tried the following, but it selects every index in indices for each row of sequences instead of only one:
sequences[:, indices]
How can I make this query without a slow and ugly for loop?
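For reference, here is a minimal sketch of the loop-based version the question wants to avoid, together with the shape that the attempted sequences[:, indices] actually produces. The random inputs are placeholders that only mirror the shapes described above.

```python
import torch

# Example inputs with the shapes from the question (random values are placeholders).
sequences = torch.randn(8, 12, 2)
indices = torch.randint(0, sequences.size(1), (sequences.size(0),))

# The attempted query pairs every row with every index: shape (8, 8, 2).
wrong = sequences[:, indices]

# Loop baseline: for row i, take sequences[i, indices[i]] (a length-2 vector).
expected = torch.stack(
    [sequences[i, int(indices[i])] for i in range(sequences.size(0))]
)

print(wrong.shape)     # torch.Size([8, 8, 2])
print(expected.shape)  # torch.Size([8, 2])
```

The wanted result is the "diagonal" of the wrong one: wrong[i, i] equals expected[i] for every row.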
Upvotes: 13
Views: 15192
Reputation: 164
It can be done using torch.Tensor.gather
import torch
sequences = torch.randn(8, 12, 2)
# defining the indices as a 1D Tensor of random integers and reshaping it to use with Tensor.gather
indices = torch.randint(sequences.size(1),(sequences.size(0),)).unsqueeze(-1).unsqueeze(-1).repeat(1,1,sequences.size(-1))
# indices shape: (8, 1, 2)
output = sequences.gather(1,indices).squeeze(1)
# output shape: (8, 2)
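A possible variant of the same gather call, sketched under the assumption that the index tensor is only read: expand creates a broadcast view instead of copying the data like repeat does. As a cross-check, the sketch compares the result with plain advanced indexing using torch.arange, which is a different technique from the one in this answer.

```python
import torch

sequences = torch.randn(8, 12, 2)
indices = torch.randint(0, sequences.size(1), (sequences.size(0),))  # shape (8,)

# Reshape to (8, 1, 1), then broadcast to (8, 1, 2) without copying.
gather_index = indices.view(-1, 1, 1).expand(-1, 1, sequences.size(-1))
via_gather = sequences.gather(1, gather_index).squeeze(1)  # shape (8, 2)

# Advanced indexing: pair row i with indices[i].
rows = torch.arange(sequences.size(0))
via_indexing = sequences[rows, indices]  # shape (8, 2)

print(torch.equal(via_gather, via_indexing))  # True
```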
Upvotes: 0
Reputation: 21
This should be doable with torch.gather, but you first need to convert your index tensor:
1. unsqueeze it to match the number of dimensions of your input tensor;
2. repeat_interleave it to match the size of the last dimension.
Here is an example based on your description:
import torch

# example inputs with the shapes from the question
sequences = torch.randn(8, 12, 2)
indices = torch.randint(0, 12, (8,))

# original indices dimension [8]
# after first unsqueeze, dimension is [8, 1]
indices = torch.unsqueeze(indices, 1)
# after second unsqueeze, dimension is [8, 1, 1]
indices = torch.unsqueeze(indices, 2)
# after repeat, dimension is [8, 1, 2]
indices = torch.repeat_interleave(indices, 2, dim=2)
# now you have the right dimension for torch.gather
# don't forget to squeeze the redundant dimension (only dim 1, so a batch of size 1 is kept)
# result has dimension [8, 2]
result = torch.gather(sequences, 1, indices).squeeze(1)
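Which dimension gets squeezed matters when the batch has a single row. A minimal sketch, assuming a hypothetical batch of size 1 (an edge case not in the question), shows that a bare squeeze() also drops the batch dimension while squeeze(1) keeps it:

```python
import torch

sequences = torch.randn(1, 12, 2)  # batch of one sequence
indices = torch.tensor([5])        # one index per row, shape (1,)

# Same preparation as above: (1,) -> (1, 1, 1) -> (1, 1, 2)
gather_index = indices.view(-1, 1, 1).repeat_interleave(sequences.size(2), dim=2)

picked = torch.gather(sequences, 1, gather_index)  # shape (1, 1, 2)
print(picked.squeeze().shape)   # torch.Size([2]): batch dimension lost
print(picked.squeeze(1).shape)  # torch.Size([1, 2]): batch dimension kept
```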
Upvotes: 0
Reputation: 424
torch.index_select solves your problem more easily than torch.gather, since you don't have to adapt the dimensions of the indices. The indices must be a tensor. For your case:
import torch
indices = [0, 2]
a = torch.rand(size=(3, 3, 3))
torch.index_select(a, dim=1, index=torch.tensor(indices, dtype=torch.long))
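To see how index_select relates to the question, the sketch below (with a random tensor and index list chosen only for illustration) checks that index_select along dim=1 returns the same as slicing a[:, idx], i.e. every listed index is taken for every row. For comparison, a per-row pick with one index per row is done with gather; the per_row values are hypothetical.

```python
import torch

a = torch.rand(3, 3, 3)
idx = torch.tensor([0, 2], dtype=torch.long)

# index_select keeps both listed indices for every row: shape (3, 2, 3).
selected = torch.index_select(a, dim=1, index=idx)
print(torch.equal(selected, a[:, idx]))  # True

# A per-row pick needs one index per row (length 3 here), e.g. via gather.
per_row = torch.tensor([0, 2, 1])
picked = a.gather(1, per_row.view(-1, 1, 1).expand(-1, 1, a.size(2))).squeeze(1)
print(picked.shape)  # torch.Size([3, 3])
```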
Upvotes: 0