Reputation: 83
Suppose I have an all-zero mask tensor like this:
mask = torch.zeros(5,3, dtype=torch.bool)
Now I want to set the values of mask at the intersection of the following rows and cols indices to True:
rows = torch.tensor([0,2,4])
cols = torch.tensor([1,2])
I would like to produce the following result:
tensor([[False,  True,  True],
        [False, False, False],
        [False,  True,  True],
        [False, False, False],
        [False,  True,  True]])
When I try the following code, I receive an error:
mask[rows, cols] = True
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [3], [2]
How can I do that efficiently in PyTorch?
Upvotes: 2
Views: 2459
Reputation: 7693
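The two index tensors must be broadcastable against each other, and shapes [3] and [2] are not. You can use torch.unsqueeze to turn cols into a column of shape [2, 1], which broadcasts with rows (shape [3]) to shape [2, 3] and covers every (row, col) combination: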
mask = torch.zeros(5,3, dtype=torch.bool)
mask[rows, cols.unsqueeze(1)] = True
mask
tensor([[False,  True,  True],
        [False, False, False],
        [False,  True,  True],
        [False, False, False],
        [False,  True,  True]])
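To see why the unsqueeze makes the assignment work, it can help to look at the index pairs that broadcasting produces. The following is a minimal sketch, not part of the original answer; it uses torch.broadcast_tensors purely for illustration, and the names r, c, and pairs are introduced here.

```python
import torch

rows = torch.tensor([0, 2, 4])
cols = torch.tensor([1, 2])

# rows has shape [3]; cols.unsqueeze(1) has shape [2, 1].
# Broadcasting expands both to the common shape [2, 3].
r, c = torch.broadcast_tensors(rows, cols.unsqueeze(1))

# Each (r[i, j], c[i, j]) pair is one (row, col) position
# that the indexed assignment writes to.
pairs = torch.stack([r, c], dim=-1).reshape(-1, 2)

print(r.shape)
print(pairs.tolist())
```

Every combination of a row index with a column index appears exactly once in pairs, which is what "intersection" means in the question.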
Alternatively, reshape produces the same [2, 1] column and gives the same result:
mask[rows, cols.reshape(-1,1)] = True
mask
tensor([[False,  True,  True],
        [False, False, False],
        [False,  True,  True],
        [False, False, False],
        [False,  True,  True]])
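For completeness, here is a self-contained sketch that checks the result against the output requested in the question. The second variant, which adds the extra dimension to rows instead of cols, is an assumption of this example and does not appear in the answer above; the names mask_a, mask_b, and expected are likewise illustrative.

```python
import torch

rows = torch.tensor([0, 2, 4])
cols = torch.tensor([1, 2])

# Variant from the answer: cols becomes a column of shape [2, 1].
mask_a = torch.zeros(5, 3, dtype=torch.bool)
mask_a[rows, cols.unsqueeze(1)] = True

# Symmetric variant (not in the original answer): rows becomes
# a column of shape [3, 1], which broadcasts with cols to [3, 2].
mask_b = torch.zeros(5, 3, dtype=torch.bool)
mask_b[rows[:, None], cols] = True

# The result requested in the question.
expected = torch.tensor([
    [False, True, True],
    [False, False, False],
    [False, True, True],
    [False, False, False],
    [False, True, True],
])

assert torch.equal(mask_a, expected)
assert torch.equal(mask_b, expected)
```

Note that chaining the two indexing steps, as in mask[rows][:, cols] = True, would not modify mask: indexing with a tensor returns a copy, so the assignment lands on a temporary.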
Upvotes: 3