Reputation: 5646
I have a 2D tensor and a 1D tensor:
import torch
torch.manual_seed(0)
out = torch.randn((16,2))
target = torch.tensor([0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0])
For each row of out, I want to select the column indexed by the corresponding entry of target. Thus, my output will be a (16,1) tensor. I tried the solution mentioned here:
https://stackoverflow.com/a/58937071
But I get:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3369, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-7-50d103c3b56c>", line 1, in <cell line: 1>
    out.gather(1, target)
RuntimeError: Index tensor must have the same number of dimensions as input tensor
Can you help?
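To pin down the requested result, here is a minimal sketch that builds it with an explicit Python loop. It is slow and meant only as a reference. It also shows the dimension mismatch that triggers the error above. The names expected and err are illustrative.

```python
import torch

torch.manual_seed(0)
out = torch.randn((16, 2))
target = torch.tensor([0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0])

# Reference (slow) version: for row i, take column target[i], then reshape to (16, 1)
expected = torch.stack([out[i, target[i]] for i in range(out.shape[0])]).unsqueeze(1)
print(expected.shape)  # torch.Size([16, 1])

# The failing call: `out` has 2 dims but the index `target` has only 1
print(out.dim(), target.dim())  # 2 1
try:
    out.gather(1, target)
except RuntimeError as err:
    print("gather failed:", err)
```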
Reputation: 40648
To apply torch.gather, the index tensor must have the same number of dimensions as the input tensor. You should therefore unsqueeze an additional dimension on target in the last position:
>>> out.gather(1, target[:,None])
tensor([[-1.1258],
[-0.4339],
[ 0.6920],
[-2.1152],
[ 0.3223],
[ 0.3500],
[ 1.2377],
[ 1.1168],
[-1.6959],
[ 0.7935],
[ 0.5988],
[-0.3414],
[ 0.7502],
[ 0.1835],
[ 1.5863],
[ 0.9463]])
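For dim=1, torch.gather computes result[i][j] = input[i][index[i][j]]. That is why the index needs shape (16, 1) here. The self-contained sketch below uses target.unsqueeze(1), which is equivalent to target[:, None]. It also shows advanced indexing, a common alternative rather than part of the answer above, which pairs each row number with its column and returns a 1D (16,) tensor instead of (16, 1).

```python
import torch

torch.manual_seed(0)
out = torch.randn((16, 2))
target = torch.tensor([0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0])

# gather along dim=1: result[i][j] = out[i][index[i][j]], so index must be (16, 1)
picked = out.gather(1, target.unsqueeze(1))  # same as target[:, None]
print(picked.shape)  # torch.Size([16, 1])

# Alternative via advanced indexing: row i is paired with column target[i]
rows = torch.arange(out.size(0))
picked_idx = out[rows, target]  # shape (16,)

# Both approaches select the same values; gather just keeps the extra dimension
print(torch.equal(picked.squeeze(1), picked_idx))  # True
```

If a (16, 1) result is needed from the indexing version, picked_idx.unsqueeze(1) restores the trailing dimension.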