I have an index tensor of size (2, 3):
>>> index = torch.empty(6).random_(0,8).view(2,3)
>>> index
tensor([[6., 3., 2.],
        [3., 4., 7.]])
And a value tensor of size (2, 8):
>>> value = torch.zeros(2,8)
>>> value
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
I want to set the elements of value to 1 at the positions given by index along dim=-1. The output should look like:
>>> output
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 1., 1., 0., 0., 1.]])
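To pin down the intended semantics, here is a naive loop-based sketch (a reference, not an efficient solution). It assumes a fixed copy of the example index instead of random_, so the result is reproducible: for every row i and every entry j, it sets value[i, index[i, j]] to 1.

```python
import torch

# Fixed copy of the example index (instead of random_) so the result is reproducible
index = torch.tensor([[6., 3., 2.],
                      [3., 4., 7.]])
value = torch.zeros(2, 8)

# Naive reference: for every row i and every entry j, set value[i, index[i, j]] = 1
for i in range(index.size(0)):
    for j in range(index.size(1)):
        value[i, int(index[i, j])] = 1

print(value)
```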
I tried value[range(2), index] = 1, but it raises an error. I also tried torch.index_fill, but it doesn't accept batched indices. torch.scatter requires creating an extra tensor of size 2*8 full of ones, which costs unnecessary memory and time.
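For reference, both attempts can be made to work; this is a sketch assuming the example index above. In value[range(2), index], the row index has shape (2,), which cannot broadcast against index's shape (2, 3), and index holds floats rather than integers. A (2, 1) int64 row index fixes both. Separately, scatter_'s src only has to be at least as large as index in every dimension, so a (2, 3) tensor of ones (torch.ones_like(index)) is enough rather than a (2, 8) one.

```python
import torch

index = torch.tensor([[6., 3., 2.],
                      [3., 4., 7.]])

# Option A: advanced indexing. The row index has shape (2, 1) so it broadcasts
# against index's shape (2, 3); both index tensors are int64.
rows = torch.arange(index.size(0)).unsqueeze(1)  # shape (2, 1), dtype int64
a = torch.zeros(2, 8)
a[rows, index.long()] = 1

# Option B: scatter_ with a src tensor. src only needs index's shape (2, 3),
# not the full (2, 8), and must have the same dtype as the target (float32 here).
b = torch.zeros(2, 8)
b.scatter_(-1, index.long(), torch.ones_like(index))
```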
You can actually use torch.Tensor.scatter_ by passing the value option (a scalar) instead of the src option (a Tensor), so no extra tensor has to be created.
>>> value.scatter_(dim=-1, index=index.long(), value=1)
>>> value
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 1., 1., 0., 0., 1.]])
Make sure index is of type int64 (torch.long), though; that is why index.long() is passed above.
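If the index is created as int64 in the first place, the .long() conversion can be skipped. A sketch, assuming torch.randint is an acceptable replacement for torch.empty(...).random_(...): torch.randint returns int64 by default. It also uses the out-of-place Tensor.scatter with a scalar value, which returns a new tensor and leaves the zeros tensor untouched.

```python
import torch

torch.manual_seed(0)

# torch.randint returns int64 by default, so no .long() call is needed
index = torch.randint(0, 8, (2, 3))
zeros = torch.zeros(2, 8)

# Out-of-place scatter with a scalar value: returns a new tensor,
# `zeros` itself is not modified
out = zeros.scatter(-1, index, 1)
print(index)
print(out)
```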