sisaman

Reputation: 83

How to change the values of a 2d tensor in certain rows and columns

Suppose I have an all-zero mask tensor like this:

import torch
mask = torch.zeros(5, 3, dtype=torch.bool)

Now I want to set mask to True at every intersection of the following row and column indices, i.e. at every pair (r, c) with r in rows and c in cols:

rows = torch.tensor([0,2,4]) 
cols = torch.tensor([1,2])

I would like to produce the following result:

tensor([[False, True,  True ],
        [False, False, False],
        [False, True,  True ],
        [False, False, False],
        [False, True,  True ]])

When I try the following code, I receive an error:

mask[rows, cols] = True

IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [3], [2]

How can I do that efficiently in PyTorch?

Upvotes: 2

Views: 2459

Answers (1)

Dishin H Goyani

Reputation: 7693

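The index tensors need shapes that can be broadcast together. rows has shape [3] and cols has shape [2], which cannot be broadcast. Adding a dimension to cols with torch.unsqueeze gives it shape [2, 1], which broadcasts with [3] to a [2, 3] grid covering every (row, col) pair: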

mask = torch.zeros(5,3, dtype=torch.bool)
mask[rows, cols.unsqueeze(1)] = True
mask
tensor([[False,  True,  True],
        [False, False, False],
        [False,  True,  True],
        [False, False, False],
        [False,  True,  True]])
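To see why this works, it can help to inspect the broadcast shapes directly. The sketch below assumes PyTorch 1.8 or later, where torch.broadcast_shapes is available. It shows that [3] and [2, 1] broadcast to [2, 3], and that unsqueezing rows instead gives the transposed [3, 2] grid, which selects the same cells:

```python
import torch

rows = torch.tensor([0, 2, 4])
cols = torch.tensor([1, 2])

# Shapes [3] and [2] cannot broadcast, but [3] and [2, 1] can (and so can [3, 1] and [2]).
print(torch.broadcast_shapes(rows.shape, cols.unsqueeze(1).shape))  # torch.Size([2, 3])
print(torch.broadcast_shapes(rows.unsqueeze(1).shape, cols.shape))  # torch.Size([3, 2])

# Either orientation enumerates every (row, col) pair, so both masks end up identical.
mask_a = torch.zeros(5, 3, dtype=torch.bool)
mask_a[rows, cols.unsqueeze(1)] = True

mask_b = torch.zeros(5, 3, dtype=torch.bool)
mask_b[rows.unsqueeze(1), cols] = True  # equivalent to rows[:, None]

print(torch.equal(mask_a, mask_b))  # True
```

Which tensor you unsqueeze only changes the shape of the intermediate index grid, not which cells are written.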

Alternatively, use torch.reshape to get the same [2, 1] shape:

mask[rows, cols.reshape(-1,1)] = True
mask
tensor([[False,  True,  True],
        [False, False, False],
        [False,  True,  True],
        [False, False, False],
        [False,  True,  True]])
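A different technique, not part of the answer above, avoids advanced indexing on the 2D tensor entirely: build a 1D boolean selector for rows and one for columns, then take their outer AND via broadcasting. This is a minimal sketch; grid_mask is a hypothetical helper name. It returns a new tensor, so to merge into an existing mask you could use mask |= grid_mask(...).

```python
import torch

def grid_mask(n_rows, n_cols, rows, cols):
    """Hypothetical helper: True at every (r, c) with r in rows and c in cols."""
    row_sel = torch.zeros(n_rows, dtype=torch.bool)
    row_sel[rows] = True
    col_sel = torch.zeros(n_cols, dtype=torch.bool)
    col_sel[cols] = True
    # Outer AND: [n_rows, 1] & [1, n_cols] broadcasts to [n_rows, n_cols].
    return row_sel[:, None] & col_sel[None, :]

rows = torch.tensor([0, 2, 4])
cols = torch.tensor([1, 2])
print(grid_mask(5, 3, rows, cols))
```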

Upvotes: 3
