namespace-Pt

Reputation: 1904

Batched index_fill in PyTorch

I have an index tensor of size (2, 3):

>>> index = torch.empty(6).random_(0,8).view(2,3)
>>> index
tensor([[6., 3., 2.],
        [3., 4., 7.]])

And a value tensor of size (2, 8):

>>> value = torch.zeros(2,8)
>>> value
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

I want to set the elements of value to 1 at the positions given by index along dim=-1. The output should be:

>>> output
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 1., 1., 0., 0., 1.]])

I tried value[range(2), index] = 1, but it raises an error. I also tried torch.index_fill, but it doesn't accept batched indices. torch.scatter requires creating an extra 2*8 tensor full of 1s, which costs unnecessary memory and time.
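For reference, here is a minimal sketch of how the first two attempts could be made to work, assuming the same shapes as above with index hard-coded instead of drawn at random. In value[range(2), index], the row index has shape (2,), which does not broadcast against the (2, 3) column index, and a float tensor is not a valid index anyway. Reshaping the row index to (2, 1) and casting to int64 fixes both. index_fill_ can also be used on a flattened view of value by shifting each row's column indices by row * num_cols. The names rows, flat_index, out_indexing and out_fill are illustrative only.

```python
import torch

# Index tensor from the question, hard-coded (float, as in the question)
index = torch.tensor([[6., 3., 2.],
                      [3., 4., 7.]])
batch, num_cols = 2, 8

# Option 1: advanced indexing. Row indices of shape (2, 1) broadcast
# against the (2, 3) column indices; indices must be an integer dtype.
rows = torch.arange(batch).unsqueeze(1)              # shape (2, 1)
out_indexing = torch.zeros(batch, num_cols)
out_indexing[rows, index.long()] = 1

# Option 2: index_fill_ on a flattened view. Offsetting each row's column
# indices by row * num_cols addresses the right slots in the 1-D view.
flat_index = (index.long() + rows * num_cols).view(-1)  # shape (6,)
out_fill = torch.zeros(batch, num_cols)
out_fill.view(-1).index_fill_(0, flat_index, 1)      # writes through to out_fill

print(out_indexing)
print(out_fill)
```

Both produce the target tensor from the question; the flattened variant relies on out_fill being contiguous so that view(-1) shares its storage.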

Upvotes: 2

Views: 989

Answers (1)

Ivan

Reputation: 40648

You can actually use torch.Tensor.scatter_ by passing the value argument (a scalar) instead of the src argument (a Tensor).

>>> value.scatter_(dim=-1, index=index.long(), value=1)

>>> value
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 1., 1., 0., 0., 1.]])

Make sure the index is of type int64 (torch.long) though, hence the .long() call.
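A self-contained version of this approach, assuming the index values shown in the question rather than a random draw. It also shows the out-of-place torch.Tensor.scatter, which returns a new tensor and leaves the zeros untouched, in case the original value tensor must be kept. The names idx and out are illustrative.

```python
import torch

index = torch.tensor([[6., 3., 2.],
                      [3., 4., 7.]])
value = torch.zeros(2, 8)

# scatter/scatter_ require an int64 index tensor
idx = index.long()

# The scalar value= argument fills every indexed position,
# so no src tensor of ones has to be built.
out = value.scatter(dim=-1, index=idx, value=1)   # out-of-place: value stays zeros
value.scatter_(dim=-1, index=idx, value=1)        # in-place: modifies value

print(out)
print(value)
```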

Upvotes: 3
