Chris

Reputation: 1306

How can I select single indices over a dimension in pytorch?

Assume I have a tensor sequences of shape [8, 12, 2]. For each position i along the first dimension, I would like to select one entry along the second dimension, giving a tensor of shape [8, 2]. The position to select along dimension 1 for each i is stored in a long tensor indices of shape [8].

I tried the following, but it selects every index in indices for each position along the first dimension of sequences (giving shape [8, 8, 2]) instead of only the one that belongs to it:

sequences[:, indices]

How can I make this query without a slow and ugly for loop?
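For reference, the loop the question wants to avoid might look like the sketch below. The function name select_loop is hypothetical and the tensors are random stand-ins with the shapes from the question. It also confirms that sequences[:, indices] has shape [8, 8, 2] rather than [8, 2].

```python
import torch

def select_loop(sequences: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: pick entry indices[i] along dim 1 of sequences[i], one i at a time.
    return torch.stack([sequences[i, int(indices[i])] for i in range(sequences.size(0))])

sequences = torch.randn(8, 12, 2)
indices = torch.randint(0, sequences.size(1), (sequences.size(0),))

wanted = select_loop(sequences, indices)
print(wanted.shape)                  # torch.Size([8, 2])
print(sequences[:, indices].shape)   # torch.Size([8, 8, 2])
```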

Upvotes: 13

Views: 15192

Answers (4)

Arun

Reputation: 164

It can be done using torch.Tensor.gather:

import torch

sequences = torch.randn(8,12,2)
# define the indices as a 1-D tensor of random integers and reshape it for use with Tensor.gather
indices = torch.randint(sequences.size(1),(sequences.size(0),)).unsqueeze(-1).unsqueeze(-1).repeat(1,1,sequences.size(-1))
# indices shape: (8, 1, 2)
output = sequences.gather(1,indices).squeeze(1)
# output shape: (8, 2)
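A variant of the same gather approach, sketched here under the question's shapes, uses expand instead of repeat: expand broadcasts the size-1 last dimension as a view instead of copying the index values. The variable name index is just for illustration.

```python
import torch

sequences = torch.randn(8, 12, 2)
indices = torch.randint(0, sequences.size(1), (sequences.size(0),))

# Reshape [8] -> [8, 1, 1], then broadcast (no copy) to [8, 1, 2] so that
# gather uses the same position along dim 1 for both entries of the last dim.
index = indices.view(-1, 1, 1).expand(-1, 1, sequences.size(2))

# gather along dim 1: output[i, 0, k] = sequences[i, index[i, 0, k], k]
output = sequences.gather(1, index).squeeze(1)
print(output.shape)  # torch.Size([8, 2])
```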

Upvotes: 0

Lock-not-gimbal

Reputation: 309

sequences[torch.arange(sequences.size(0)), indices]
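This works because the two index tensors are combined elementwise (advanced indexing): the i-th entry of torch.arange(sequences.size(0)) is paired with the i-th entry of indices, so the result holds sequences[i, indices[i]] for every i, with the last dimension kept whole. A small check with random stand-in data; passing device=sequences.device to arange is an addition here so the sketch also works when the tensors live on a GPU.

```python
import torch

sequences = torch.randn(8, 12, 2)
indices = torch.randint(0, sequences.size(1), (sequences.size(0),))

rows = torch.arange(sequences.size(0), device=sequences.device)  # [0, 1, ..., 7]
result = sequences[rows, indices]  # pairs (rows[i], indices[i]) elementwise
print(result.shape)  # torch.Size([8, 2])

# Same values as an explicit loop over the first dimension.
expected = torch.stack([sequences[i, int(indices[i])] for i in range(sequences.size(0))])
print(torch.equal(result, expected))  # True
```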

Upvotes: 9

gunxueqiu

Reputation: 21

This can be done with torch.gather, but you first need to reshape your index tensor:

  • unsqueeze it to match the number of dimensions of your input tensor
  • repeat_interleave it to match the size of the last dimension

Here is an example based on your description:

import torch

# sequences has dimension [8, 12, 2], indices has dimension [8] (as in the question)
# after first unsqueeze, dimension is [8, 1]
indices = torch.unsqueeze(indices, 1)

# after second unsqueeze, dimension is [8, 1, 1]
indices = torch.unsqueeze(indices, 2)

# after repeat, dimension is [8, 1, 2]
indices = torch.repeat_interleave(indices, 2, dim=2)

# now you have the right dimension for torch.gather
# don't forget to squeeze the redundant dimension (dim 1 only, so a batch of 1 is kept)
# result has dimension [8, 2]
result = torch.gather(sequences, 1, indices).squeeze(1)
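One caveat with calling .squeeze() without an argument: it removes every size-1 dimension, so with a batch of a single sequence the batch dimension disappears as well. Passing the dimension explicitly avoids that. A small sketch with a batch size of 1 (the shapes and the index value 5 are chosen only for illustration):

```python
import torch

sequences = torch.randn(1, 12, 2)  # batch containing a single sequence
indices = torch.tensor([5])        # one index per batch entry

index = indices.view(-1, 1, 1).repeat_interleave(sequences.size(2), dim=2)  # [1, 1, 2]
gathered = torch.gather(sequences, 1, index)                                # [1, 1, 2]

print(gathered.squeeze().shape)   # torch.Size([2])     batch dim lost
print(gathered.squeeze(1).shape)  # torch.Size([1, 2])  batch dim kept
```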

Upvotes: 0

Niklas Höpner

Reputation: 424

torch.index_select solves your problem more easily than torch.gather since you don't have to adapt the dimensions of the indices. The indices must be passed as a tensor. For example:

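import torch
indices = [0, 2]
a = torch.rand(size=(3, 3, 3))
# selects positions 0 and 2 along dim 1 (the same positions for every slice along dim 0)
torch.index_select(a, dim=1, index=torch.tensor(indices, dtype=torch.long))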
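When comparing with the other answers, note that torch.index_select applies the same index list to every slice along the remaining dimensions. With the question's indices of shape [8], it therefore behaves like sequences[:, indices] and returns shape [8, 8, 2]; the per-row selection asked for corresponds to entries [i, i] of that result. A quick check with random stand-ins for the question's tensors:

```python
import torch

sequences = torch.randn(8, 12, 2)
indices = torch.randint(0, sequences.size(1), (sequences.size(0),))

selected = torch.index_select(sequences, dim=1, index=indices)
print(selected.shape)                                # torch.Size([8, 8, 2])
print(torch.equal(selected, sequences[:, indices]))  # True

# selected[i, j] == sequences[i, indices[j]], so the wanted rows are selected[i, i].
rows = torch.arange(sequences.size(0))
print(torch.equal(selected[rows, rows], sequences[rows, indices]))  # True
```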

Upvotes: 0
