Reputation: 4083
I have a 4D tensor (which happens to be a stack of three batches of 56x56 images where each batch has 16 images) with the size of [16, 3, 56, 56]. My goal is to select the correct one of those three batches (with my index map that has the size of [16, 56, 56]) for each pixel and get the images that I want.
Now, I want to select particular images from those three batches, using an index map which has values such as
[[[ 0, 0, 2, ..., 0, 0, 0],
[ 0, 0, 2, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 0],
...,
[ 0, 0, 0, ..., 0, 0, 0],
[ 0, 2, 0, ..., 0, 0, 0],
[ 0, 2, 2, ..., 0, 0, 0]],
[[ 0, 2, 0, ..., 1, 1, 0],
[ 0, 2, 0, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 2, 0],
...,
[ 0, 0, 0, ..., 0, 2, 0],
[ 0, 0, 2, ..., 0, 2, 0],
[ 0, 0, 2, ..., 0, 0, 0]]]
So for the 0s, the value is selected from the first batch, while 1 and 2 mean I want to select the value from the second and the third batch, respectively.
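To pin down the selection rule, here is a minimal pure-Python sketch on nested lists. The helper name `select_per_pixel` and the toy sizes are illustrative assumptions, not part of my actual code: for every image `b` and pixel `(h, w)`, the output takes its value from the batch numbered `indices[b][h][w]`.

```python
def select_per_pixel(stacks, indices):
    """stacks[b][c][h][w] and indices[b][h][w] (values in 0..C-1) -> out[b][h][w]."""
    return [
        [
            [stacks[b][indices[b][h][w]][h][w] for w in range(len(indices[b][h]))]
            for h in range(len(indices[b]))
        ]
        for b in range(len(indices))
    ]


# One image, three candidate 2x2 planes (toy sizes instead of 16 x 3 x 56 x 56).
stacks = [[
    [[10, 11], [12, 13]],  # batch 0
    [[20, 21], [22, 23]],  # batch 1
    [[30, 31], [32, 33]],  # batch 2
]]
indices = [[[0, 2], [1, 0]]]

result = select_per_pixel(stacks, indices)
print(result)  # [[[10, 31], [22, 13]]]
```

This is only a reference for the intended result; it would be far too slow for real tensors.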
Here are some of the visualizations of the indices, each color denoting another batch.
I have tried to transpose the 4D tensor to match the dimensions of my indices, but it did not work. All it does is give me a copy of the dimensions I tried to select. That is,
tposed = torch.transpose(fourD, 0, 1)
print(indices.size(), outs.size(), tposed[:, indices].size())
outputs
torch.Size([16, 56, 56]) torch.Size([16, 3, 56, 56]) torch.Size([3, 16, 56, 56, 56, 56])
while the shape I need is
torch.Size([16, 56, 56]) or torch.Size([16, 1, 56, 56])
and as an example, if I try to select the right values for only the first image in the batch with
fourD[0,indices].size()
I get a shape like
torch.Size([16, 56, 56, 56, 56])
Not to mention that I get an out of memory error when I try this on the whole tensor.
I appreciate any help for using these indices to select either one of these three batches for each pixel in my images.
Note :
I have tried the option
outs[indices[:,None,:,:]].size()
and that returns
torch.Size([16, 1, 56, 56, 3, 56, 56])
Edit: torch.take does not help much since it treats the input tensor as a one-dimensional array.
Upvotes: 1
Views: 3333
Reputation: 4083
Turns out there is a function in PyTorch that has the functionality I was searching for.
torch.gather(fourD, 1, indices.unsqueeze(1))
did the job.
Here is a beautiful explanation of what gather does.
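For completeness, a minimal runnable sketch of the call on random data. The random tensors and the cross-check with explicit index grids are my own illustration, assuming `indices` is an int64 tensor with values in 0..2:

```python
import torch

B, C, H, W = 16, 3, 56, 56
fourD = torch.randn(B, C, H, W)
indices = torch.randint(0, C, (B, H, W))  # int64 by default, which gather expects

# gather along dim 1: out[b, 0, h, w] = fourD[b, indices[b, h, w], h, w]
out = torch.gather(fourD, 1, indices.unsqueeze(1))  # shape [16, 1, 56, 56]
out = out.squeeze(1)                                # shape [16, 56, 56]

# Cross-check: index all four dimensions at once with broadcastable grids.
b = torch.arange(B).view(B, 1, 1)
h = torch.arange(H).view(1, H, 1)
w = torch.arange(W).view(1, 1, W)
expected = fourD[b, indices, h, w]

print(out.shape, torch.equal(out, expected))
```

The `unsqueeze(1)` is needed because gather requires the index tensor to have the same number of dimensions as the input; the output then has the shape of the index tensor.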
Upvotes: 2