cometta

Reputation: 35689

select sub elements from another batch

I have a first tensor of size torch.Size([12, 64, 8, 8, 3]), where 12 is the batch size, 64 is the number of patches, and (8, 8, 3) is the size of each image patch.

There is another tensor of size torch.Size([12, 10]) that selects 10 of the 64 patches for each item in the batch, i.e. it stores the patch indices. How can I use it to index into the first tensor with a list comprehension?

Upvotes: 0

Views: 751

Answers (2)

Madhoolika

Reputation: 444

The aim is to retrieve the first of the 10 selected patches for each batch item. Iterating over b yields, for each batch item, the tensor of its selected patch indices, and index 0 picks the first of them. Each such index is a tensor, so it is converted to int before being used to index into the image tensor a, which returns the corresponding patch for each batch item.

import torch

a = torch.rand(12, 64, 8, 8, 3)  # 12 batch items, each with 64 patches of size 8x8x3
b = torch.randint(64, (12, 10))  # 10 patch indices (out of 64) for each of the 12 batch items

# First selected patch of each batch item; each list entry has shape (8, 8, 3)
first_tensors = [a[batch, int(patches[0])] for batch, patches in enumerate(b)]

For clarity, the following list comprehension gives the [batch, patch] index pair of the first selected patch for each batch item (the patch values depend on the random b, so yours will differ):

[[batch, int(patches[0])] for batch, patches in zip(range(12), b)]
[[0, 40],
 [1, 27],
 [2, 17],
 [3, 62],
 [4, 9],
 [5, 51],
 [6, 32],
 [7, 38],
 [8, 63],
 [9, 10],
 [10, 2],
 [11, 6]]

Indexing the image tensor a with each pair of indices in the list above gives the corresponding patch.
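The comprehension above picks only the first selected patch per batch item, while the question asks for all 10. A minimal sketch of the same pattern extended to every selected patch, assuming b holds int64 indices (the default dtype of torch.randint): indexing dimension 1 with the whole row of indices, instead of patches[0], returns all 10 patches at once.

```python
import torch

a = torch.rand(12, 64, 8, 8, 3)  # 12 batch items, 64 patches each, every patch 8x8x3
b = torch.randint(64, (12, 10))  # 10 patch indices per batch item

# a[batch, patches] indexes dim 1 with a 1-D LongTensor of length 10,
# giving a tensor of shape (10, 8, 8, 3) for each batch item
per_item = [a[batch, patches] for batch, patches in enumerate(b)]

# Stack along a new leading dimension -> (12, 10, 8, 8, 3)
selected = torch.stack(per_item)
print(selected.shape)  # torch.Size([12, 10, 8, 8, 3])
```

selected[i, j] is then the patch a[i, b[i, j]].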

Upvotes: 1

Szymon Maszke

Reputation: 24691

You can use index_select:

c = [torch.index_select(i, dim=0, index=j) for i, j in zip(a,b)]

Here a is your patch tensor and b is your index tensor.

You can then stack the resulting list along dimension 0 with torch.stack to get a single tensor of shape [12, 10, 8, 8, 3].
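A minimal sketch of the index_select-then-stack approach, using random a and b as stand-ins for the real tensors. It also shows a loop-free alternative that is not part of the original answer: PyTorch advanced indexing with a (12, 1) batch index broadcast against the (12, 10) patch indices, which should give the same result.

```python
import torch

a = torch.rand(12, 64, 8, 8, 3)  # stand-in for the patch tensor
b = torch.randint(64, (12, 10))  # stand-in for the index tensor

# Answer's approach: index_select per batch item, then stack along dim 0
c = [torch.index_select(i, dim=0, index=j) for i, j in zip(a, b)]
stacked = torch.stack(c, dim=0)  # shape (12, 10, 8, 8, 3)

# Vectorized alternative (no Python loop): batch_idx has shape (12, 1) and
# broadcasts against b's shape (12, 10), so vectorized[i, j] = a[i, b[i, j]]
batch_idx = torch.arange(a.size(0)).unsqueeze(1)
vectorized = a[batch_idx, b]  # shape (12, 10, 8, 8, 3)

print(torch.equal(stacked, vectorized))  # True
```

The list comprehension matches the question's request; the advanced-indexing form avoids the Python-level loop, which may matter for larger batches.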

Upvotes: 0
