Reputation: 2741
I have two binary masks of shape (batch_size, width, height), and I want to create a binary mask that indicates the union of the elements of the two.
To find the intersection, I can use torch.where(A == B, 1, 0), but how can I find the union?
Upvotes: 1
Views: 2295
Reputation: 114876
When working with binary masks, you should use logical operations such as logical_or() and logical_and().
The intersection is then the binary mask:
intersection = A.logical_and(B)
and the union is:
union = A.logical_or(B)
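Here is a minimal sketch applying this to batched masks of shape (batch_size, width, height). The tensor values are made up for illustration. It assumes the masks are stored as bool tensors, and it casts the result to 0/1 integers to match the torch.where(..., 1, 0) output from the question:

```python
import torch

# Two hypothetical batches of binary masks, shape (batch_size, width, height) = (2, 2, 3)
A = torch.tensor([[[1, 0, 1],
                   [0, 0, 1]],
                  [[0, 1, 1],
                   [1, 0, 0]]], dtype=torch.bool)
B = torch.tensor([[[1, 1, 0],
                   [0, 0, 1]],
                  [[0, 0, 1],
                   [0, 1, 0]]], dtype=torch.bool)

# Element-wise AND / OR; the result keeps the (batch_size, width, height) shape
intersection = A.logical_and(B)
union = A.logical_or(B)

# Cast the boolean masks to 0/1 integers, like torch.where(cond, 1, 0) would give
intersection_int = intersection.long()
union_int = union.long()

print(intersection_int)
print(union_int)
```

For bool tensors, the operators A & B and A | B give the same results as logical_and and logical_or.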
BTW, I'll leave it to you as an exercise to check why the intersection you computed (A == B) is not correct.
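If you want to check your answer to that exercise, here is a small comparison. It uses hypothetical 1-D bool masks, kept tiny so the difference is easy to spot. It puts torch.where(A == B, 1, 0) next to logical_and:

```python
import torch

# Hypothetical 1-D masks covering all four (A, B) combinations
A = torch.tensor([1, 1, 0, 0], dtype=torch.bool)
B = torch.tensor([1, 0, 1, 0], dtype=torch.bool)

# A == B is True wherever the masks agree, including where BOTH are 0
equal = torch.where(A == B, 1, 0)

# An intersection should be 1 only where BOTH masks are 1
intersection = A.logical_and(B).long()

print(equal)         # last position is 1 even though neither mask is set there
print(intersection)  # only the first position is 1
```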
Upvotes: 1