Haha TTpro

Indexing the second dimension of a tensor using indices

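I select elements of my tensor using a tensor of indices. In the code below, a permutation of the indices selects 11, 15, 2, 5. (Note that i comes from a second call to torch.randperm, so it is not the 0, 3, 2, 1 printed by the first call; the result 11, 15, 2, 5 corresponds to indices 2, 3, 1, 0.)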

>>> import torch
>>> a = torch.Tensor([5,2,11, 15])
>>> torch.randperm(4)

 0
 3
 2
 1
[torch.LongTensor of size 4]

>>> i = torch.randperm(4)
>>> a[i]

 11
 15
  2
  5
[torch.FloatTensor of size 4]
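Indexing a 1-D tensor with an index tensor is a gather: position j of the result holds a[i[j]]. The pure-Python sketch below applies that rule (the helper name gather_1d is hypothetical, not a PyTorch function). It shows that the output 11, 15, 2, 5 comes from the permutation 2, 3, 1, 0, while 0, 3, 2, 1 would give 5, 15, 11, 2:

```python
def gather_1d(values, indices):
    """Hypothetical helper: result[j] = values[indices[j]], like a[i] for a 1-D a."""
    return [values[k] for k in indices]


a = [5, 2, 11, 15]
print(gather_1d(a, [2, 3, 1, 0]))  # [11, 15, 2, 5]
print(gather_1d(a, [0, 3, 2, 1]))  # [5, 15, 11, 2]
```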

Now, I have

>>> b = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> b

  5   2  11  15
  5   2  11  15
  5   2  11  15
[torch.FloatTensor of size 3x4]

Now I want to use indices to select the columns in the order 2, 3, 1, 0. In other words, I want a tensor like this:

>>> b

 11  15   2   5
 11  15   2   5
 11  15   2   5
[torch.FloatTensor of size 3x4]


Answers (1)

entrophy

If using PyTorch v0.1.12

In this version there isn't an easy way to do this. PyTorch aims to make tensor manipulation work like NumPy's, but some capabilities are still missing, and this is one of them.

With NumPy arrays this is easy:

>>> import numpy as np
>>> i = [2, 1, 0, 3]
>>> a = np.array([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> a[:, i]

array([[11,  2,  5, 15],
       [11,  2,  5, 15],
       [11,  2,  5, 15]])
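a[:, i] keeps every row and reorders the columns: column j of the result is column i[j] of a. The plain-Python sketch below applies the same rule to nested lists and reproduces the NumPy output above. The helper select_columns is hypothetical and only illustrates the semantics.

```python
def select_columns(rows, cols):
    """Hypothetical helper mirroring a[:, cols]: reorder the columns of every row."""
    return [[row[c] for c in cols] for row in rows]


a = [[5, 2, 11, 15], [5, 2, 11, 15], [5, 2, 11, 15]]
for row in select_columns(a, [2, 1, 0, 3]):
    print(row)  # [11, 2, 5, 15] on each of the three lines
```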

But the same thing with Tensors will give you an error:

>>> i = torch.LongTensor([2, 1, 0, 3])
>>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> a[:,i]

The error:

TypeError: indexing a tensor with an object of type torch.LongTensor. The only supported types are integers, slices, numpy scalars and torch.LongTensor or torch.ByteTensor as the only argument.

The TypeError means that if you index with a LongTensor or a ByteTensor, the only valid syntax is a[<LongTensor>] or a[<ByteTensor>]: the index tensor must be the sole argument. Combining it with a slice, as in a[:, i], is not supported.

Because of this limitation, you have two options:

Option 1: Convert to numpy, permute, then back to Tensor

>>> i = [2, 1, 0, 3]
>>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> np_a = a.numpy()
>>> np_a = np_a[:,i]
>>> a = torch.from_numpy(np_a)
>>> a

 11   2   5  15
 11   2   5  15
 11   2   5  15
[torch.FloatTensor of size 3x4]

Option 2: Move the dimension you want to permute to dimension 0, permute, then move it back

Move the dimension you want to permute (in your case dim=1) to dimension 0, perform the permutation there, and then move it back. It's a bit hacky, but it gets the job done.

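import torch

def hacky_permute(a, i, dim):
    # Move `dim` to the front so that plain a[i] indexes along it
    a = torch.transpose(a, 0, dim)
    a = a[i]  # reorder along dimension 0, the only form v0.1.12 supports
    a = torch.transpose(a, 0, dim)  # move the dimension back
    return a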

And use it like so:

>>> i = torch.LongTensor([2, 1, 0, 3])
>>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> a = hacky_permute(a, i, dim=1)
>>> a

 11   2   5  15
 11   2   5  15
 11   2   5  15
[torch.FloatTensor of size 3x4]
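The transpose trick in hacky_permute works because swapping dim with axis 0 turns the columns into rows. Plain a[i] then reorders what were columns, and the second transpose restores the original layout. Below is a sketch of the same three steps in NumPy. Here np.swapaxes stands in for torch.transpose, and the result is checked against np.take(a, i, axis=dim). The function name permute_along is hypothetical.

```python
import numpy as np


def permute_along(a, i, dim):
    """Hypothetical NumPy mirror of hacky_permute: reorder `a` along `dim` by `i`."""
    a = np.swapaxes(a, 0, dim)      # bring `dim` to the front
    a = a[i]                        # index along axis 0 only
    return np.swapaxes(a, 0, dim)   # move the axis back


a = np.array([[5, 2, 11, 15], [5, 2, 11, 15], [5, 2, 11, 15]])
i = np.array([2, 1, 0, 3])
out = permute_along(a, i, dim=1)
print(out[0])                                      # [11  2  5 15]
print(np.array_equal(out, np.take(a, i, axis=1)))  # True
```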

If using PyTorch v0.2.0

In this version, indexing with a tensor combined with a slice works directly, i.e.:

>>> i = torch.LongTensor([2, 1, 0, 3])
>>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> a[:,i]

 11   2   5  15
 11   2   5  15
 11   2   5  15
[torch.FloatTensor of size 3x4]
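In current PyTorch releases, torch.index_select(input, dim, index) is another option besides a[:, i]. It selects along any dimension without the transpose trick. The sketch below assumes a recent PyTorch version and builds the tensors with the torch.tensor factory. Integer lists become int64 tensors, which index_select accepts as indices.

```python
import torch

a = torch.tensor([[5, 2, 11, 15], [5, 2, 11, 15], [5, 2, 11, 15]])
i = torch.tensor([2, 1, 0, 3])  # int64 index tensor

by_indexing = a[:, i]                              # advanced indexing on dim 1
by_select = torch.index_select(a, dim=1, index=i)  # explicit dimension argument

print(by_select)
print(torch.equal(by_indexing, by_select))  # True
```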
