I am trying to understand the behavior of index_put in PyTorch, but the documentation is not clear to me.
Given
a = torch.zeros(2, 3)
a.index_put([torch.tensor(1, 0), torch.tensor([1, 1])], torch.tensor(1.))
it returns
tensor([[1., 1., 0.],
[0., 0., 0.])
While given
a = torch.zeros(2, 3)
a.index_put([torch.tensor(0, 0), torch.tensor([1, 1])], torch.tensor(1.))
it returns
tensor([[0., 1., 0.],
[0., 0., 0.])
What exactly is the rule behind index_put? And what if I want to put three values into a, so that it returns
tensor([[0., 1., 1.],
[0., 1., 0.]])
Any help is appreciated!
I copied your examples here with argument names inserted, fixed brackets, and corrected outputs (yours were swapped):
a.index_put(indices=[torch.tensor([1, 0]), torch.tensor([1, 1])], values=torch.tensor(1.))
tensor([[0., 1., 0.],
[0., 1., 0.]])
a.index_put(indices=[torch.tensor([0, 0]), torch.tensor([0, 1])], values=torch.tensor(1.))
tensor([[1., 1., 0.],
[0., 0., 0.]])
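What this method does is write value(s) into the positions indicated by indices, in a copy of the original a tensor (a itself is not modified; the in-place variant is index_put_). indices is a list with one index tensor per dimension: the first tensor holds the x (row) coordinates of the target positions and the second holds the y (column) coordinates, so the k-th position is (indices[0][k], indices[1][k]). values may be a single value, which is written to every position, or a 1d tensor with one value per position.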
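To make the pairing rule concrete, here is a minimal pure-Python model of what index_put does for a 2-D input. The helper name index_put_2d is hypothetical (it is not part of PyTorch), and it assumes plain nested lists instead of tensors; it only mirrors the rule described above: zip the row and column index lists into (row, col) pairs, broadcast a single value, and write into a copy.

```python
import copy


def index_put_2d(matrix, indices, values):
    """Hypothetical pure-Python model of index_put for a 2-D input.

    matrix  : list of rows (list of lists of numbers)
    indices : [row_indices, col_indices], two equal-length lists
    values  : a single number (broadcast) or a list with one value per position
    """
    rows, cols = indices
    if len(rows) != len(cols):
        raise ValueError("row and column index lists must have the same length")
    # A single value is broadcast to every indexed position.
    if not isinstance(values, (list, tuple)):
        values = [values] * len(rows)
    if len(values) != len(rows):
        raise ValueError("need one value per indexed position")
    # Like index_put (not index_put_), work on a copy and leave the input untouched.
    out = copy.deepcopy(matrix)
    # The k-th target position is the pair (rows[k], cols[k]).
    # Note: with duplicate positions this model lets the last write win,
    # whereas PyTorch does not guarantee which value wins.
    for r, c, v in zip(rows, cols, values):
        out[r][c] = v
    return out


a = [[0.0] * 3 for _ in range(2)]
print(index_put_2d(a, [[1, 0], [1, 1]], 1.0))
print(index_put_2d(a, [[0, 0, 1], [1, 2, 1]], [1.0, 2.0, 3.0]))
```

Reading the indices column-wise like this, [[1, 0], [1, 1]] means positions (1, 1) and (0, 1), which is why both entries of column 1 become 1.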
To obtain the desired output, use:
a.index_put(indices=[torch.tensor([0,0,1]), torch.tensor([1, 2, 1])], values=torch.tensor(1.))
tensor([[0., 1., 1.],
[0., 1., 0.]])
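As a sanity check, the same result can be obtained with ordinary advanced-indexing assignment on a clone, or in place with index_put_. This sketch assumes a standard PyTorch install; the variable names rows, cols, b and c are just for illustration.

```python
import torch

a = torch.zeros(2, 3)
rows = torch.tensor([0, 0, 1])  # x (row) index of each target position
cols = torch.tensor([1, 2, 1])  # y (column) index of each target position

# Out-of-place: returns a new tensor, `a` stays all zeros.
out = a.index_put((rows, cols), torch.tensor(1.0))

# Same result via advanced-indexing assignment on a copy.
b = a.clone()
b[rows, cols] = 1.0

# In-place variant: index_put_ modifies c itself (and also returns it).
c = a.clone()
c.index_put_((rows, cols), torch.tensor(1.0))

print(out)
print(torch.equal(out, b), torch.equal(out, c))  # True True
```

In other words, a.index_put([rows, cols], v) behaves like a[rows, cols] = v applied to a copy of a.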
Moreover, you can pass multiple values in the values argument to write them into the indicated positions, one value per position:
a.index_put(indices=[torch.tensor([0,0,1]), torch.tensor([1, 2, 1])], values=torch.tensor([1., 2., 3.]))
tensor([[0., 1., 2.],
[0., 3., 0.]])
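All positions in the examples above are distinct. If the same position appears more than once, the PyTorch documentation states that with the default accumulate=False the result is undefined (no guarantee which value ends up there). Passing accumulate=True instead adds every value into its position. A small sketch, assuming the same 2x3 zero tensor:

```python
import torch

a = torch.zeros(2, 3)
# Position (0, 1) appears twice.
rows = torch.tensor([0, 0, 1])
cols = torch.tensor([1, 1, 2])
vals = torch.tensor([1.0, 2.0, 3.0])

# accumulate=True adds each value to the existing entry,
# so (0, 1) receives 1 + 2 and (1, 2) receives 3.
summed = a.index_put((rows, cols), vals, accumulate=True)
print(summed)
```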