gavra

Reputation: 93

Vectorized way to apply a 3-dimension mask to RGB in pytorch

I have an HxWx3 tensor representing an RGB image and an HxWx3 boolean mask tensor as input. For each (i,j), the mask has exactly one true value (that is, exactly one of R/G/B is on). I want to apply the mask to the image to get an HxW (or HxWx1) tensor V where V[i,j] is the R/G/B value selected by the mask.

Following the question "Problem applying binary mask to an RGB image with numpy", I was able to achieve the following:

>>> X*mask
tensor([[[ 9., 10.],
         [ 0.,  0.]],

        [[ 0.,  0.],
         [ 0., 20.]],

        [[ 0.,  0.],
         [30.,  0.]]])

But as stated, I want a single HxW tensor as the result, not HxWx3.


Upvotes: 2

Views: 270

Answers (1)

Mercury

Reputation: 4181

Assuming that for each (i,j) only a single R/G/B value is retained, the other two channels are zeroed out by the mask, so you can simply sum over the channel dimension:

(X*mask).sum(axis=2)

This should give you your desired (HxW) output.
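
For a concrete check, here is a minimal sketch of the approach on a made-up 2x2 image. The values in `X` and `mask` are assumptions chosen for illustration, stored in the HxWx3 layout described in the question. The output printed in the question looks channel-first (3xHxW); if your tensors are laid out that way, the sum would be over `dim=0` instead. The boolean-indexing variant at the end is an alternative I am adding, not part of the answer above; it relies on exactly one True per pixel.

```python
import torch

# Hypothetical 2x2 RGB image in HxWx3 (channels-last) layout.
X = torch.tensor([[[9., 1., 2.], [10., 3., 4.]],
                  [[5., 6., 30.], [7., 20., 8.]]])

# Boolean mask of the same shape with exactly one True per pixel.
mask = torch.tensor([[[True, False, False], [True, False, False]],
                     [[False, False, True], [False, True, False]]])

# Masked-out channels become 0, so summing over the channel
# dimension leaves only the selected value at each (i, j).
V = (X * mask).sum(dim=2)

# Alternative sketch: boolean indexing returns the selected values
# in row-major order, one per pixel, so they reshape to HxW.
V_alt = X[mask].view(X.shape[:2])

print(V)
print(V_alt)
```

One caveat with the multiply-and-sum form: if a masked-out channel holds `inf` or `nan`, multiplying by zero yields `nan`, whereas the indexing form never touches those entries.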

Upvotes: 4
