LOST

Reputation: 3250

PyTorch index in a batch

Given a tensor IN of shape (A, B, C, D) and an index tensor IDX of shape (A, B, C) with torch.long values in [0, C), how can I get a tensor OUT of shape (A, B, C, D) such that:

OUT[a, b, c, :] == IN[a, b, IDX[a, b, c], :]

This is trivial without dimensions A and B:

import torch

# C = 2, D = 3
IN = torch.arange(6).view(2, 3)
IDX = torch.tensor([0, 0])

print(IN[IDX])
# tensor([[0, 1, 2],
#         [0, 1, 2]])

Obviously, I can write a nested for loop over A and B. But surely there must be a vectorized way to do it?

Upvotes: 1

Views: 1071

Answers (1)

Ivan

Reputation: 40618

This is the perfect use case for torch.gather. Given two 4D tensors, input (the source tensor) and index (the tensor holding the indices into input), calling torch.gather with dim=2 returns a tensor out with the same shape as index (here also the shape of input) such that:

out[i][j][k][l] = input[i][j][index[i][j][k][l]][l]

In other words, index selects along the third dimension (dim=2) of input.
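To make the rule above concrete, here is a minimal sketch that compares torch.gather against a literal element-by-element loop. The shapes and the names inp, index and expected are illustrative assumptions, not taken from the question:

```python
import torch

torch.manual_seed(0)
A, B, C, D = 2, 3, 4, 5                     # arbitrary illustrative sizes
inp = torch.rand(A, B, C, D)
index = torch.randint(0, C, (A, B, C, D))   # 4D index with values in [0, C)

# Reference: apply out[i][j][k][l] = inp[i][j][index[i][j][k][l]][l] literally.
expected = torch.empty_like(inp)
for i in range(A):
    for j in range(B):
        for k in range(C):
            for l in range(D):
                expected[i, j, k, l] = inp[i, j, index[i, j, k, l], l]

out = torch.gather(inp, dim=2, index=index)
print(torch.equal(out, expected))
```

Note that with a general 4D index, every position along the last dimension may pick a different row. The question needs the same row for the whole last dimension, which is why the index is expanded in the next step.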

Before applying this function, though, note that the input and index tensors must have the same number of dimensions. Since idx is only 3D, we need to append a fourth dimension to it and expand that dimension to match x (the input tensor, IN in the question). We can do so with the following line:

>>> idx_ = idx[...,None].expand_as(x)
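As a rough illustration of what this line does, the sketch below inspects the shapes and strides involved. The sizes are assumptions chosen to match the small example further down:

```python
import torch

A, B, C, D = 1, 2, 3, 2
x = torch.rand(A, B, C, D)
idx = torch.randint(0, C, (A, B, C))

# (A, B, C) -> (A, B, C, 1) via [..., None], then -> (A, B, C, D) via expand_as
idx_ = idx[..., None].expand_as(x)

print(idx.shape, idx_.shape)
# expand_as returns a view: the expanded last dimension has stride 0, so no data is copied
print(idx_.stride())
# Every slice along the last dimension repeats the original indices
print(all(torch.equal(idx_[..., d], idx) for d in range(D)))
```

Because the expanded index is only a view, this step should cost no extra memory for the index itself.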

Then call torch.gather:

>>> x.gather(dim=2, index=idx_)
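Putting the two steps together, a minimal sketch could look like the following. The helper names batched_index and batched_index_loop are hypothetical; the loop version is only there as a reference to check the vectorized one against:

```python
import torch

def batched_index(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Vectorized: out[a, b, c, :] == x[a, b, idx[a, b, c], :]."""
    idx_ = idx[..., None].expand_as(x)   # (A, B, C) -> (A, B, C, D)
    return x.gather(dim=2, index=idx_)

def batched_index_loop(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Reference: the nested loop over A and B mentioned in the question."""
    out = torch.empty_like(x)
    A, B = x.shape[:2]
    for a in range(A):
        for b in range(B):
            out[a, b] = x[a, b][idx[a, b]]   # (C, D) rows selected by (C,) indices
    return out

torch.manual_seed(0)
x = torch.rand(2, 3, 4, 5)
idx = torch.randint(0, 4, (2, 3, 4))
print(torch.equal(batched_index(x, idx), batched_index_loop(x, idx)))
```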

You can try out the solution with this code:

>>> A = 1; B = 2; C=3; D=2

>>> x = torch.rand(A,B,C,D)
tensor([[[[0.6490, 0.7670],
          [0.7847, 0.9058],
          [0.3606, 0.7843]],

         [[0.0666, 0.7306],
          [0.1923, 0.3513],
          [0.5287, 0.3680]]]])

>>> idx = torch.randint(0, C, (A,B,C))
tensor([[[1, 2, 2],
         [0, 0, 1]]])

>>> x.gather(dim=2, index=idx[...,None].expand_as(x))
tensor([[[[0.7847, 0.9058],
          [0.3606, 0.7843],
          [0.3606, 0.7843]],

         [[0.0666, 0.7306],
          [0.0666, 0.7306],
          [0.1923, 0.3513]]]])
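For comparison, the same result can also be obtained with integer-array (advanced) indexing instead of torch.gather. This is an alternative technique, not the method described above, and the names a and b are assumptions for illustration:

```python
import torch

torch.manual_seed(0)
A, B, C, D = 2, 3, 4, 5
x = torch.rand(A, B, C, D)
idx = torch.randint(0, C, (A, B, C))

# Index tensors for the first two dimensions, shaped to broadcast against idx
a = torch.arange(A)[:, None, None]   # (A, 1, 1)
b = torch.arange(B)[None, :, None]   # (1, B, 1)

# a, b and idx broadcast to (A, B, C); the trailing dimension D is kept whole
out_indexing = x[a, b, idx]          # (A, B, C, D)

out_gather = x.gather(dim=2, index=idx[..., None].expand_as(x))
print(torch.equal(out_indexing, out_gather))
```

Both versions avoid Python loops. The gather version states the indexed dimension explicitly, which may be easier to read when the number of leading batch dimensions changes.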

Upvotes: 1
