Ahmad

Reputation: 9658

Torch: Update tensor with non-zero elements

Suppose I have:

>>> a = torch.tensor([1, 2, 3, 0, 0, 1])
>>> b = torch.tensor([0, 1, 3, 3, 0, 0])

I want to update b with the elements of a wherever they are nonzero. How can I do that efficiently?

Expected:

>>> b = torch.tensor([1, 2, 3, 3, 0, 1])

Upvotes: 0

Views: 131

Answers (2)

A.Mounir

Reputation: 588

To add to the previous answer: for more simplicity, you can do it in one line of code:

b = torch.where(a != 0, a, b)

Output:

tensor([1, 2, 3, 3, 0, 1])
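The torch.where one-liner above returns a new tensor. If you would rather modify b in place, boolean-mask indexing is one option. This is a minimal sketch, assuming a and b have the same shape and dtype as in the question:

```python
import torch

a = torch.tensor([1, 2, 3, 0, 0, 1])
b = torch.tensor([0, 1, 3, 3, 0, 0])

# Boolean mask: True wherever a holds a nonzero value
mask = a != 0

# Overwrite only the masked positions of b in place;
# positions where a is 0 keep b's original value
b[mask] = a[mask]

print(b)  # tensor([1, 2, 3, 3, 0, 1])
```

This writes into b's existing storage instead of allocating a separate result tensor.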

Upvotes: 1

DerekG

Reputation: 3938

torch.where is your answer. Based on your example, I assume you want to replace only the elements of a that are 0, filling them with the corresponding elements of b:

mask = torch.logical_and(b != 0, a == 0)
output = torch.where(mask, b, a)
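This mask-based version gives the same result as torch.where(a != 0, a, b). Where a is nonzero, both keep a. Where a is 0 and b is nonzero, both take b. Where both are 0, either choice yields 0. Here is a quick check using the tensors from the question:

```python
import torch

a = torch.tensor([1, 2, 3, 0, 0, 1])
b = torch.tensor([0, 1, 3, 3, 0, 0])

# Mask-based approach: take b only where a is 0 and b is nonzero, otherwise keep a
mask = torch.logical_and(b != 0, a == 0)
output = torch.where(mask, b, a)

# One-line approach: take a wherever it is nonzero, otherwise b
one_liner = torch.where(a != 0, a, b)

print(output)                          # tensor([1, 2, 3, 3, 0, 1])
print(torch.equal(output, one_liner))  # True
```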

Upvotes: 1
