Minh-Long Luu

Reputation: 2731

PyTorch: compare three tensors?

I have three boolean mask tensors, and I want to create a boolean mask that is 1 where the value matches in all three tensors and 0 otherwise.

I tried torch.where(A == B == C, 1, 0), but it doesn't seem to support that.

Upvotes: 3

Views: 879

Answers (3)

iacob

Reputation: 24181

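The torch.eq operator (==) only compares two tensors at a time. Python expands the chained A == B == C into (A == B) and (B == C), and the and operator calls bool() on a multi-element tensor, which raises an error. Hence you need two elementwise comparisons combined with &: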

(A==B) & (B==C)
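A minimal sketch of the above, assuming small 1-D boolean masks (the values of A, B and C are illustrative, not from the question) and a recent PyTorch release where torch.where accepts Python scalars for its second and third arguments. It shows the chained form failing and the two-comparison form producing the requested 1/0 mask:

```python
import torch

# Illustrative boolean masks (assumed values, not from the question)
A = torch.tensor([True, False, True, False])
B = torch.tensor([True, False, False, False])
C = torch.tensor([True, True, True, False])

# Python rewrites A == B == C as (A == B) and (B == C); the `and`
# calls bool() on a 4-element tensor, which PyTorch rejects.
try:
    torch.where(A == B == C, 1, 0)
except RuntimeError as err:
    print("chained comparison failed:", err)

# Two elementwise comparisons combined with the bitwise AND operator
same = (A == B) & (B == C)
mask = torch.where(same, 1, 0)
print(mask)  # tensor([1, 0, 0, 1])
```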

Upvotes: 1

Cynichniy Bandera

Reputation: 6103

AFAIK, a tensor is essentially a NumPy-like array bound to a device. If it is not too expensive for your application and you can afford to do it on the CPU, you can simply convert the tensors to NumPy and do the comparison there.
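A rough sketch of the CPU/NumPy route, assuming the same illustrative masks as above. Note that NumPy also rejects the chained A == B == C on arrays (it raises a ValueError about an ambiguous truth value), so the comparison is still written as two pairwise ones:

```python
import numpy as np
import torch

# Illustrative boolean masks (assumed values, not from the question)
A = torch.tensor([True, False, True, False])
B = torch.tensor([True, False, False, False])
C = torch.tensor([True, True, True, False])

# Copy to host memory (no-op for CPU tensors) and view as NumPy arrays
a, b, c = (t.cpu().numpy() for t in (A, B, C))

# Compare pairwise and combine with an elementwise logical AND
same = np.logical_and(a == b, b == c)
mask = same.astype(np.int64)
print(mask)  # [1 0 0 1]

# Convert back to a tensor if the rest of the pipeline needs one
mask_tensor = torch.from_numpy(mask)
```

Tensors that require gradients would additionally need .detach() before .numpy(); boolean masks cannot require gradients, so that step is omitted here.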

Upvotes: 0

GoodDeeds

Reputation: 8507

You can use:

((A == B) & (B == C))
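The same pairwise idea extends to any number of masks by comparing each one against the first. A minimal sketch, where all_equal is a hypothetical helper name (not a PyTorch function) and all inputs are assumed to share one shape:

```python
import torch

def all_equal(*tensors: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: True where every tensor holds the same value."""
    first = tensors[0]
    result = torch.ones_like(first, dtype=torch.bool)
    for t in tensors[1:]:
        # Keep only positions where this tensor also matches the first one
        result &= first == t
    return result

# Illustrative masks (assumed values)
A = torch.tensor([True, False, True, False])
B = torch.tensor([True, False, False, False])
C = torch.tensor([True, True, True, False])

print(all_equal(A, B, C))  # tensor([ True, False, False,  True])
```

For three tensors this gives the same result as ((A == B) & (B == C)), since equality is transitive.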

If required, you can always convert the boolean tensor to an appropriate type:

((A == B) & (B == C)).to(float)
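A short sketch of the conversion options, assuming illustrative masks. Passing the Python type float to .to() selects torch.float64; .long() gives the integer 0/1 mask the question asked for:

```python
import torch

# Illustrative masks (assumed values)
A = torch.tensor([True, False, True, False])
B = torch.tensor([True, False, False, False])
C = torch.tensor([True, True, True, False])

same = (A == B) & (B == C)

as_float = same.to(float)  # Python float is treated as torch.float64
as_long = same.long()      # int64 tensor of 0s and 1s

print(as_float)  # tensor([1., 0., 0., 1.], dtype=torch.float64)
print(as_long)   # tensor([1, 0, 0, 1])
```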

Upvotes: 0
