Pratik K.

Reputation: 331

Indexing a batched set of images

Suppose I have two index tensors and an image tensor. How can I sample the (x, y) points from the image?

img.shape # -> (batch x H x W x 3)
x.shape # -> (batch x H x W)
y.shape # -> (batch x H x W)

Here H x W is height x width. Basically, I want to perform something like a batched "shuffle" of the image pixel intensities.
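To pin down what "sampling the (x, y) points" means, here is a minimal, deliberately slow loop sketch. It assumes the reading output[a, b, c, d] == img[a, x[a, b, c], y[a, b, c], d], i.e. x indexes the height axis and y the width axis. The helper name shuffle_reference is hypothetical and only serves as a reference to check a vectorized version against.

```python
import torch


def shuffle_reference(img, x, y):
    # Hypothetical helper: loop-based definition of the batched "shuffle".
    # For every batch index a and output pixel (b, c), copy all 3 channels
    # of the source pixel img[a, x[a, b, c], y[a, b, c]].
    batch, H, W, _ = img.shape
    output = torch.empty_like(img)
    for a in range(batch):
        for b in range(H):
            for c in range(W):
                src_h = int(x[a, b, c])  # source row
                src_w = int(y[a, b, c])  # source column
                output[a, b, c] = img[a, src_h, src_w]
    return output


# toy data with the shapes from the question
batch, H, W = 2, 4, 5
img = torch.arange(batch * H * W * 3).reshape(batch, H, W, 3)
x = torch.randint(0, H, (batch, H, W))
y = torch.randint(0, W, (batch, H, W))
output = shuffle_reference(img, x, y)
```

With identity indices (x[a, b, c] == b and y[a, b, c] == c) this returns the image unchanged, which is a quick way to sanity-check any faster implementation.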

Upvotes: 1

Views: 290

Answers (1)

Jatentaki

Reputation: 13103

I am assuming you want output[a, b, c, d] == img[a, x[a, b, c], y[a, b, c], d], where a, b, c and d iterate over batch, H, W and 3, respectively. You can solve that with torch.gather. As you can see in the documentation, it performs a similar indexing operation along a single dimension. Note that gathering twice (first on dim 1 with x, then on dim 2 with y) is not equivalent in general: the second gather reads img[a, x[a, b, y[a, b, c]], y[a, b, c], d], i.e. it picks up the x value stored at column y[a, b, c] instead of column c. So instead we flatten H and W into a single dimension, in which pixel (h, w) sits at position h * W + w, and gather once along it with the linear index x * W + y. Unfortunately gather does not broadcast, so to deal with the trailing RGB dimension we have to add an extra dimension to the index and manually repeat it. The code looks like this:

import torch

# prepare data as in the example
batch, H, W = 2, 4, 5
img = torch.arange(batch * H * W * 3).reshape(batch, H, W, 3)
x = torch.randint(0, H, (batch, H, W))
y = torch.randint(0, W, (batch, H, W))

# flatten the spatial dims: pixel (h, w) lives at position h * W + w
img_flat = img.reshape(batch, H * W, 3)
idx = (x * W + y).reshape(batch, H * W, 1)

# deal with `torch.gather` not broadcasting: repeat the index over RGB
idx = idx.repeat(1, 1, 3)

# do the actual indexing and restore the (batch, H, W, 3) layout
output = torch.gather(img_flat, dim=1, index=idx).reshape(batch, H, W, 3)
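As an alternative to torch.gather, PyTorch's advanced (integer-array) indexing can express the same lookup in one step, provided you also pass a batch index that broadcasts against x and y. A minimal sketch on the same toy data; b_idx is just an illustrative name, and it assumes the same reading output[a, b, c, d] == img[a, x[a, b, c], y[a, b, c], d]:

```python
import torch

# toy data with the shapes from the question
batch, H, W = 2, 4, 5
img = torch.arange(batch * H * W * 3).reshape(batch, H, W, 3)
x = torch.randint(0, H, (batch, H, W))
y = torch.randint(0, W, (batch, H, W))

# A batch index of shape (batch, 1, 1) broadcasts against x and y (batch, H, W),
# so element [a, b, c] of the three index tensors is (a, x[a, b, c], y[a, b, c]).
b_idx = torch.arange(batch).view(batch, 1, 1)

# Advanced indexing over the first three dims; the trailing RGB dim is carried
# along untouched, so the result has shape (batch, H, W, 3).
output = img[b_idx, x, y]
```

Here the broadcasting is done by the indexing itself, so there is no need to repeat the index over the RGB dimension as with gather.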

Upvotes: 2
