u3728666

Reputation: 119

Understanding behavior of index_put in PyTorch

I am trying to understand the behavior of index_put in PyTorch, but the documentation is not clear to me.

Given

a = torch.zeros(2, 3)
a.index_put([torch.tensor(1, 0), torch.tensor([1, 1])], torch.tensor(1.))

it returns

tensor([[1., 1., 0.], 
       [0., 0., 0.])

While given

a = torch.zeros(2, 3)
a.index_put([torch.tensor(0, 0), torch.tensor([1, 1])], torch.tensor(1.))

it returns

tensor([[0., 1., 0.], 
       [0., 0., 0.])

I am wondering: what on earth is the rule behind index_put? And what if I want to put three values into a, such that it returns

tensor([[0., 1., 1.],
        [0., 1., 0.]])

Any help is appreciated!

Upvotes: 4

Views: 5013

Answers (1)

Poe Dator

Reputation: 4913

I copied your examples here with the argument names inserted, the brackets fixed and the output corrected (yours was swapped):

a.index_put(indices=[torch.tensor([1, 0]), torch.tensor([1, 1])], values=torch.tensor(1.))

tensor([[0., 1., 0.],
        [0., 1., 0.]])

a.index_put(indices=[torch.tensor([0, 0]), torch.tensor([0, 1])], values=torch.tensor(1.))

tensor([[1., 1., 0.],
        [0., 0., 0.]])

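What this method does is insert value(s) into the locations of the original tensor a that are indicated by indices. indices is a list with one index tensor per dimension: the first holds the x (row) coordinates of the insertions and the second holds the y (column) coordinates, so the k-th value goes to position (indices[0][k], indices[1][k]). values may be a single value or a 1-D tensor with one entry per position.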
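To make the pairing rule concrete, here is a minimal sketch that reproduces index_put with a plain loop over (row, column) pairs and checks it against PyTorch. index_put_2d_reference is a hypothetical helper written for illustration, not part of PyTorch; it assumes a 2-D tensor, accumulate=False, and no repeated positions.

```python
import torch

def index_put_2d_reference(a, rows, cols, values):
    """Hypothetical helper (not a PyTorch API): loop-based equivalent of
    a.index_put([rows, cols], values) for a 2-D tensor, assuming
    accumulate=False and no duplicate (row, col) pairs."""
    out = a.clone()  # work on a copy, leaving a untouched
    # Broadcast a single value (or keep a 1-D tensor) to one entry per pair
    vals = torch.as_tensor(values, dtype=a.dtype).expand(rows.shape)
    for k in range(rows.numel()):
        # The k-th insertion goes to position (rows[k], cols[k])
        out[int(rows[k]), int(cols[k])] = vals[k]
    return out

a = torch.zeros(2, 3)
rows = torch.tensor([0, 0, 1])  # x (row) coordinates
cols = torch.tensor([1, 2, 1])  # y (column) coordinates

expected = a.index_put([rows, cols], torch.tensor(1.))
assert torch.equal(index_put_2d_reference(a, rows, cols, torch.tensor(1.)), expected)
print(expected)
```

The loop makes it explicit that the index tensors are read column-wise as coordinate pairs, not as "row 1 and row 0" in one list and "column 1 and column 1" in another.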

To obtain the desired output, use:

a.index_put(indices=[torch.tensor([0,0,1]), torch.tensor([1, 2, 1])], values=torch.tensor(1.))

tensor([[0., 1., 1.],
        [0., 1., 0.]])
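The PyTorch documentation describes tensor.index_put_(indices, values) as equivalent to tensor[indices] = values, while index_put (without the trailing underscore) returns a new tensor and leaves the original unchanged. A short sketch checking that all three spellings fill the same positions, using the coordinates from the desired-output example:

```python
import torch

rows = torch.tensor([0, 0, 1])  # row index of each insertion
cols = torch.tensor([1, 2, 1])  # column index of each insertion

# Out-of-place: returns a new tensor, a stays all zeros
a = torch.zeros(2, 3)
b = a.index_put([rows, cols], torch.tensor(1.))

# In-place variant: modifies c directly
c = torch.zeros(2, 3)
c.index_put_((rows, cols), torch.tensor(1.))

# Advanced-indexing assignment writes to the same positions in place
d = torch.zeros(2, 3)
d[rows, cols] = 1.

print(b)
print(torch.equal(b, c), torch.equal(b, d))
```

In everyday code the d[rows, cols] = ... form is usually the most readable choice; index_put is handy when you need a modified copy without touching the original.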

Moreover, you can pass multiple values in the values argument to insert them into the indicated positions, one value per position:

a.index_put(indices=[torch.tensor([0,0,1]), torch.tensor([1, 2, 1])], values=torch.tensor([1., 2., 3.]))

tensor([[0., 1., 2.],
        [0., 3., 0.]])
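A related parameter is accumulate (default False). According to the PyTorch documentation, with accumulate=True the values are added to the existing entries instead of overwriting them, and with accumulate=False the behavior is undefined if indices contain duplicate positions. A small sketch where position (0, 1) is listed twice:

```python
import torch

a = torch.zeros(2, 3)
rows = torch.tensor([0, 0, 1])
cols = torch.tensor([1, 1, 1])  # position (0, 1) appears twice

# accumulate=True adds values into place instead of overwriting,
# so a repeated position receives the sum of its values
summed = a.index_put([rows, cols], torch.tensor([1., 2., 3.]), accumulate=True)
print(summed)  # (0, 1) holds 1 + 2 = 3, (1, 1) holds 3
```

If repeated positions are possible in your indices and you want a well-defined result, accumulate=True is the safer choice.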

Upvotes: 3
