Reputation: 35689
I have a first tensor of size torch.Size([12, 64, 8, 8, 3]), where (8, 8, 3) is the image size, 64 is the number of patches, and 12 is the batch size.
There is another tensor of size torch.Size([12, 10]) that selects 10 of the 64 patches for each item in the batch, so it stores patch indices. How can I use it to index into the first tensor with a list comprehension?
Upvotes: 0
Views: 751
Reputation: 444
The aim is to query the 1st patch (out of the 10 selected patches) of each batch item. Iterating over b yields, for each batch item, the tensor of selected patch indices. Take the first of them with index 0. Since that element is a tensor, convert it to int so it can index into the image tensor and retrieve the corresponding patch for each batch item.
import torch

a = torch.rand(12, 64, 8, 8, 3) # batch of 12 items, each with 64 patches of size 8x8x3
b = torch.randint(64, (12, 10)) # 10 patch indices (in the range 0..63) for each of the 12 batch items
first_tensors = [a[batch, int(patches[0])] for batch, patches in zip(range(12), b)]
For clarity, the list comprehension below gives the [batch, patch] index pair of the 1st selected patch of each batch item (the exact values depend on the random b).
[[batch, int(patches[0])] for batch, patches in zip(range(12), b)]
[[0, 40],
[1, 27],
[2, 17],
[3, 62],
[4, 9],
[5, 51],
[6, 32],
[7, 38],
[8, 63],
[9, 10],
[10, 2],
[11, 6]]
Indexing the image tensor a with each pair of indices in the list above gives the corresponding patch.
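If all 10 selected patches per batch item are wanted rather than only the first, the same idea should extend by indexing each batch item with its whole row of indices. This is a minimal sketch, assuming a and b are shaped as in the question; the names selected and first_patches are illustrative only.

```python
import torch

a = torch.rand(12, 64, 8, 8, 3)   # 12 batch items, 64 patches each, 8x8x3 per patch
b = torch.randint(64, (12, 10))   # 10 patch indices per batch item

# Iterating over a yields tensors of shape (64, 8, 8, 3).
# Indexing one of them with a 1-D index tensor picks 10 patches: (10, 8, 8, 3).
selected = [a_i[idx] for a_i, idx in zip(a, b)]

# The first selected patch of each batch item, as in the answer above.
first_patches = [a[i, int(idx[0])] for i, idx in enumerate(b)]
```

Each entry of selected holds the 10 chosen patches for one batch item, so selected[i][0] is the same patch as first_patches[i].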
Upvotes: 1
Reputation: 24691
You can use index_select:
c = [torch.index_select(i, dim=0, index=j) for i, j in zip(a,b)]
Here a and b are your tensor and indices, respectively.
You can then stack the resulting list along dimension 0 with torch.stack.
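As a sketch of the stacking step, assuming a and b are shaped as in the question, the list can be combined with torch.stack. For comparison, the block also shows a loop-free alternative using advanced indexing, which is not part of the original answer; the names stacked, batch_idx and vectorized are illustrative.

```python
import torch

a = torch.rand(12, 64, 8, 8, 3)   # 12 batch items, 64 patches each
b = torch.randint(64, (12, 10))   # 10 patch indices per batch item

# One (10, 8, 8, 3) tensor per batch item, as in the answer.
c = [torch.index_select(a_i, dim=0, index=idx) for a_i, idx in zip(a, b)]

# Stack along a new leading dimension: (12, 10, 8, 8, 3).
stacked = torch.stack(c, dim=0)

# Alternative without a Python loop: a (12, 1) column of batch indices
# broadcasts against the (12, 10) patch indices.
batch_idx = torch.arange(a.shape[0]).unsqueeze(1)
vectorized = a[batch_idx, b]      # (12, 10, 8, 8, 3)
```

Both results should hold the same values. The list comprehension matches what the question asked for, while the indexed version avoids building an intermediate list.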
Upvotes: 0