Suppose I have:
>>> a = torch.tensor([1, 2, 3, 0, 0, 1])
>>> b = torch.tensor([0, 1, 3, 3, 0, 0])
I want to update `b` with the elements of `a` wherever those elements are nonzero. How can I do that efficiently?
Expected:
>>> b = torch.tensor([1, 2, 3, 3, 0, 1])
To add to the previous answer: for simplicity, you can do it in one line of code:
b = torch.where(a != 0, a, b)
Output:
tensor([1, 2, 3, 3, 0, 1])
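A self-contained sketch of the same one-liner, assuming a recent PyTorch install. It also shows an in-place alternative using boolean-mask indexing, which overwrites only the selected positions of an existing tensor instead of building a new one. The names `updated`, `b_inplace`, and `mask` are just illustrative.

```python
import torch

a = torch.tensor([1, 2, 3, 0, 0, 1])
b = torch.tensor([0, 1, 3, 3, 0, 0])

# Out-of-place: take a where a is nonzero, otherwise keep b.
updated = torch.where(a != 0, a, b)
print(updated)  # tensor([1, 2, 3, 3, 0, 1])

# In-place alternative: clone b so the original stays intact, then
# overwrite only the positions where a is nonzero.
b_inplace = b.clone()
mask = a != 0
b_inplace[mask] = a[mask]
print(b_inplace)  # tensor([1, 2, 3, 3, 0, 1])
```

If you don't need to keep the original `b`, drop the `clone()` and assign into `b` directly.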
`torch.where` is your answer. Based on your example, I assume you want to replace only the elements of `a` that are 0 with the corresponding nonzero elements of `b`:
mask = torch.logical_and(b != 0, a == 0)
output = torch.where(mask, b, a)
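A quick check, using the question's `a` and `b`, that this mask-based version gives the same result as `torch.where(a != 0, a, b)`. Positions where both tensors are 0 fall back to `a`, which is 0 anyway, so the two formulations agree here. The name `simple` is just illustrative.

```python
import torch

a = torch.tensor([1, 2, 3, 0, 0, 1])
b = torch.tensor([0, 1, 3, 3, 0, 0])

# Take b only where a is 0 and b is nonzero; otherwise keep a.
mask = torch.logical_and(b != 0, a == 0)
output = torch.where(mask, b, a)

# Simpler one-liner: take a wherever it is nonzero, otherwise b.
simple = torch.where(a != 0, a, b)

print(output)                       # tensor([1, 2, 3, 3, 0, 1])
print(torch.equal(output, simple))  # True
```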