Bedir Yilmaz

Slicing a 4D tensor with a 3D tensor-index in PyTorch

I have a 4D tensor of size [16, 3, 56, 56], which is a stack of three batches of 56x56 images, each batch holding 16 images. For each pixel, I want to use an index map of size [16, 56, 56] to choose which of the three batches the value comes from, and thereby build the images I want.

Concretely, I want to pick one of those three batches per pixel, using an index tensor whose values look like this:

       [[[ 0,  0,  2,  ...,  0,  0,  0],
         [ 0,  0,  2,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  0,  0],
         ...,
         [ 0,  0,  0,  ...,  0,  0,  0],
         [ 0,  2,  0,  ...,  0,  0,  0],
         [ 0,  2,  2,  ...,  0,  0,  0]],

        [[ 0,  2,  0,  ...,  1,  1,  0],
         [ 0,  2,  0,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  2,  0],
         ...,
         [ 0,  0,  0,  ...,  0,  2,  0],
         [ 0,  0,  2,  ...,  0,  2,  0],
         [ 0,  0,  2,  ...,  0,  0,  0]]]

A 0 means the value should be taken from the first batch, while 1 and 2 mean it should be taken from the second and third batch, respectively.
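To pin down the intended semantics, here is a minimal loop-based reference, assuming small made-up sizes (B=2, C=3, H=W=4) in place of [16, 3, 56, 56]. `select_reference` is a hypothetical helper name, not a PyTorch function. For every pixel (b, h, w) it reads the batch/channel `indices[b, h, w]` from `fourD`. It is far too slow for real use, but it is a handy ground truth for checking faster versions.

```python
import torch

def select_reference(fourD: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Slow reference (hypothetical helper): out[b, h, w] = fourD[b, indices[b, h, w], h, w]."""
    B, C, H, W = fourD.shape
    out = torch.empty(B, H, W, dtype=fourD.dtype)
    for b in range(B):
        for h in range(H):
            for w in range(W):
                # Pick the batch/channel chosen by the index map for this pixel
                out[b, h, w] = fourD[b, int(indices[b, h, w]), h, w]
    return out

# Small stand-ins for the [16, 3, 56, 56] images and the [16, 56, 56] index map
fourD = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4)
indices = torch.randint(0, 3, (2, 4, 4))
out = select_reference(fourD, indices)
print(out.shape)  # torch.Size([2, 4, 4])
```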

Here are some visualizations of the indices, each color denoting a different batch.

[Images: visualizations of the index maps]

I have tried transposing the 4D tensor so that its dimensions line up with my indices, but it did not work: instead of picking one value per pixel, the indexing returns a full copy of the remaining dimensions for every index. For example,

tposed = torch.transpose(fourD, 0, 1)
print(indices.size(), outs.size(), tposed[:, indices].size())

outputs

torch.Size([16, 56, 56]) torch.Size([16, 3, 56, 56]) torch.Size([3, 16, 56, 56, 56, 56])

while the shape I need is

torch.Size([16, 56, 56]) or torch.Size([16, 1, 56, 56])
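The extra dimensions come from how advanced indexing works: indexing one dimension with an integer tensor replaces that dimension with the index tensor's full shape, while the other dimensions are kept as full slices. A small sketch with reduced, made-up sizes (B=2, H=W=4 instead of 16 and 56) reproduces the same pattern:

```python
import torch

B, C, H, W = 2, 3, 4, 4
fourD = torch.zeros(B, C, H, W)
indices = torch.zeros(B, H, W, dtype=torch.long)

tposed = torch.transpose(fourD, 0, 1)  # shape [C, B, H, W]
# Dimension 1 (size B) is replaced by indices.shape == [B, H, W];
# the trailing H and W dimensions of tposed are kept as full slices.
result = tposed[:, indices]
print(result.shape)  # torch.Size([3, 2, 4, 4, 4, 4])
```

At the real sizes the same rule gives 3 × 16 × 56 × 56 × 56 × 56 ≈ 4.7 × 10⁸ elements (roughly 1.9 GB in float32), so this kind of indexing grows very quickly.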

As another example, if I try to select the right values for only the first image in the batch with

fourD[0,indices].size()

I get a shape like

torch.Size([16, 56, 56, 56, 56])

On top of that, I get an out-of-memory error when I try this on the whole tensor.
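Plain advanced indexing can still produce a [16, 56, 56] result if every dimension gets an index tensor and those tensors broadcast against each other, so that each output pixel picks exactly one element. A sketch, assuming `fourD` is [B, C, H, W] and `indices` is an int64 [B, H, W] tensor; `select_by_indexing` is a hypothetical helper name:

```python
import torch

def select_by_indexing(fourD: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    B, C, H, W = fourD.shape
    # Coordinate tensors shaped so that they broadcast with indices to [B, H, W]
    b = torch.arange(B, device=fourD.device).view(B, 1, 1)
    h = torch.arange(H, device=fourD.device).view(1, H, 1)
    w = torch.arange(W, device=fourD.device).view(1, 1, W)
    # out[b, h, w] = fourD[b, indices[b, h, w], h, w]
    return fourD[b, indices, h, w]

fourD = torch.randn(16, 3, 56, 56)
indices = torch.randint(0, 3, (16, 56, 56))
out = select_by_indexing(fourD, indices)
print(out.shape)  # torch.Size([16, 56, 56])
```

Because all four index tensors broadcast to [B, H, W], the result only needs as much memory as the output itself.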

I appreciate any help for using these indices to select either one of these three batches for each pixel in my images.

Note:

I have also tried

outs[indices[:,None,:,:]].size()

and that returns

torch.Size([16, 1, 56, 56, 3, 56, 56])

Edit: torch.take does not help much, since it treats the input tensor as a one-dimensional array.
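torch.take can still be made to work, at the cost of converting each pixel's batch choice into a flat, row-major position by hand. A sketch, assuming a [B, C, H, W] input (torch.take indexes the input as if it were flattened) and an int64 index map; `select_by_take` is a hypothetical helper name:

```python
import torch

def select_by_take(fourD: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    B, C, H, W = fourD.shape
    b = torch.arange(B, device=fourD.device).view(B, 1, 1)
    h = torch.arange(H, device=fourD.device).view(1, H, 1)
    w = torch.arange(W, device=fourD.device).view(1, 1, W)
    # Row-major flat offset of element [b, indices[b, h, w], h, w]
    flat = ((b * C + indices) * H + h) * W + w
    # torch.take returns a tensor with the shape of `flat`, i.e. [B, H, W]
    return torch.take(fourD, flat)

fourD = torch.randn(2, 3, 4, 5)
indices = torch.randint(0, 3, (2, 4, 5))
out = select_by_take(fourD, indices)
```

The offset arithmetic is easy to get wrong, which is why a dimension-aware function is usually the nicer choice.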


Answers (1)

Bedir Yilmaz

It turns out PyTorch already has a function that does what I was looking for, torch.gather.

torch.gather(fourD, 1, indices.unsqueeze(1)) 

did the job.
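With dim=1, torch.gather(input, 1, index) returns a tensor shaped like `index`, where out[b][0][h][w] = input[b][index[b][0][h][w]][h][w], which is exactly the per-pixel lookup asked for. `index` must be an int64 tensor with the same number of dimensions as `input`, hence the unsqueeze(1). A sketch with random data standing in for the real tensors; the trailing squeeze(1) is an optional addition that turns [16, 1, 56, 56] into [16, 56, 56]:

```python
import torch

fourD = torch.randn(16, 3, 56, 56)           # stand-in for the stacked batches
indices = torch.randint(0, 3, (16, 56, 56))  # int64 values in {0, 1, 2}
# If the index map has another integer dtype, convert it first with indices.long()

picked = torch.gather(fourD, 1, indices.unsqueeze(1))  # [16, 1, 56, 56]
picked_2d = picked.squeeze(1)                          # [16, 56, 56]

# Spot-check one pixel against the definition
b, h, w = 5, 10, 42
assert picked_2d[b, h, w] == fourD[b, indices[b, h, w], h, w]
print(picked.shape, picked_2d.shape)
```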

Here is a beautiful explanation of what gather does.
