Reputation: 331
Suppose I have two index tensors and an image tensor. How can I sample the (x, y) points from the image?
img.shape # -> (batch x H x W x 3)
x.shape # -> (batch x H x W)
y.shape # -> (batch x H x W)
(H x W being height x width) Basically I want to perform something like a batch "shuffle" of the image pixel intensities.
Upvotes: 1
Views: 290
Reputation: 13103
I am assuming you want output[a, b, c, d] == img[a, x[a, b, c], y[a, b, c], d], where a, b, c, d are variables which iterate over batch, H, W and 3, respectively. You can solve that with torch.gather. As you can see in the documentation, it performs a similar indexing operation for a single dimension, so we first flatten H and W into one dimension and combine x and y into the linear index x * W + y. (Gathering on dim 1 with x and then again on dim 2 with y does not give the result above: the second gather reads row x[a, b, y[a, b, c]] instead of x[a, b, c].) Unfortunately gather does not broadcast, so to deal with the trailing rgb dimension we have to add an extra dimension and manually repeat it. The code looks like this
import torch
# prepare data as in the example
batch, H, W = 2, 4, 5
img = torch.arange(batch * H * W * 3).reshape(batch, H, W, 3)
x = torch.randint(0, H, (batch, H, W))
y = torch.randint(0, W, (batch, H, W))
# combine (x, y) into a single index over the flattened H * W dimension
idx = (x * W + y).reshape(batch, H * W)
# deal with `torch.gather` not broadcasting
idx = idx.unsqueeze(2).repeat(1, 1, 3)
# do the actual indexing on the flattened image, then restore the shape
flat = img.reshape(batch, H * W, 3)
output = torch.gather(flat, dim=1, index=idx).reshape(batch, H, W, 3)
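As a cross-check on the identity assumed above (output[a, b, c, d] == img[a, x[a, b, c], y[a, b, c], d]), the same result can probably also be obtained with plain advanced indexing, which broadcasts over the trailing rgb dimension by itself. This is only a minimal sketch: the names batch_idx, output_adv and expected are introduced here for illustration, and the explicit loop is just a slow reference to compare against.

```python
import torch

torch.manual_seed(0)
batch, H, W = 2, 4, 5
img = torch.arange(batch * H * W * 3).reshape(batch, H, W, 3)
x = torch.randint(0, H, (batch, H, W))
y = torch.randint(0, W, (batch, H, W))

# batch index shaped (batch, 1, 1) so that it broadcasts against x and y
batch_idx = torch.arange(batch).view(batch, 1, 1)

# advanced indexing over the first three dims; the trailing rgb dim is kept
output_adv = img[batch_idx, x, y]  # shape: (batch, H, W, 3)

# slow reference built directly from the assumed identity
expected = torch.empty_like(img)
for a in range(batch):
    for b in range(H):
        for c in range(W):
            expected[a, b, c] = img[a, x[a, b, c], y[a, b, c]]

assert torch.equal(output_adv, expected)
```

Comparing any vectorized version against such a loop on small random inputs is a cheap way to confirm that the indexing really matches the intended formula.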
Upvotes: 2