Shamoon

Reputation: 43569

How do I do masking in PyTorch / Numpy with different dimensions?

I have a mask of size torch.Size([20, 1, 199]) and two tensors, reconstruct_output and inputs, both of size torch.Size([20, 1, 161, 199]).

I want to set reconstruct_output to the values of inputs wherever the mask is 0. I tried:

reconstruct_output[mask == 0] = inputs[mask == 0]

But I get an error:

IndexError: The shape of the mask [20, 1, 199] at index 2 does not match the shape of the indexed tensor [20, 1, 161, 199] at index 2
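The error comes from how boolean-mask indexing works: the mask's shape is matched against the leading dimensions of the indexed tensor, so the mask's third dimension (199) is compared with the tensor's third dimension (161). A minimal NumPy sketch, using the small stand-in shapes (2, 1, 4) and (2, 1, 3, 4) as an assumption in place of the real sizes, reproduces the mismatch and shows one way around it: insert a length-1 axis into the mask and broadcast it to the full shape, after which the original one-line assignment works.

```python
import numpy as np

# Small stand-ins for the real shapes: (2, 1, 4) ~ [20, 1, 199] and
# (2, 1, 3, 4) ~ [20, 1, 161, 199].
rng = np.random.default_rng(0)
mask = rng.integers(0, 3, size=(2, 1, 4))
reconstruct_output = rng.integers(0, 10, size=(2, 1, 3, 4))
inputs = rng.integers(0, 10, size=(2, 1, 3, 4))
before = reconstruct_output.copy()  # kept only to check the result

# A boolean index is matched against the *leading* dimensions, so the
# mask's last axis (size 4) lines up with the row axis (size 3) and fails.
try:
    reconstruct_output[mask == 0]
except IndexError as exc:
    print("IndexError:", exc)

# Add a length-1 axis for the rows and broadcast to the full shape, so every
# row of a selected column is included in the boolean mask.
m = np.broadcast_to((mask == 0)[:, :, np.newaxis, :], inputs.shape)
reconstruct_output[m] = inputs[m]
```

In PyTorch the equivalent broadcast mask should be `(mask == 0).unsqueeze(2).expand_as(inputs)`. This is an alternative to the np.where approach in the answer, not part of it.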

Upvotes: 1

Views: 2460

Answers (1)

yatu

Reputation: 88275

We can use advanced indexing here. To build the index arrays for both reconstruct_output and inputs, we need the indices along the axes of mask where it equals 0. np.where gives us exactly those, and we can use the resulting indices to update reconstruct_output as:

m = mask == 0
i, _, l = np.where(m)
reconstruct_output[i, ..., l] = inputs[i, ..., l]
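Since the question is about PyTorch, here is a sketch of the same idea on torch tensors. It assumes single-argument `torch.where(condition)`, which behaves like `torch.nonzero(condition, as_tuple=True)` and so plays the role of np.where; `fill_masked_columns` is a hypothetical helper name, and the data is the small example worked through in this answer. Like the NumPy version, it discards the index along the second axis, which is only valid because that axis has size 1.

```python
import torch

def fill_masked_columns(reconstruct_output, inputs, mask):
    """Hypothetical helper: copy inputs into reconstruct_output where mask == 0.

    Assumes mask has shape (B, 1, W) and both tensors have shape (B, 1, H, W);
    each zero in mask selects one whole length-H column. Modifies
    reconstruct_output in place and also returns it.
    """
    # One index tensor per mask dimension, like np.where(mask == 0).
    i, _, l = torch.where(mask == 0)
    # The middle index is dropped; that is only safe because axis 1 has size 1.
    # The ellipsis keeps axes 1 and 2 whole, so every row of the column is copied.
    reconstruct_output[i, ..., l] = inputs[i, ..., l]
    return reconstruct_output

# The small example from this answer, as torch tensors.
mask = torch.tensor([[[0, 1, 2, 1]],
                     [[1, 0, 1, 0]]])
reconstruct_output = torch.tensor([[[[8, 9, 7, 2],
                                     [5, 4, 6, 1],
                                     [1, 4, 0, 3]]],
                                   [[[4, 3, 3, 4],
                                     [0, 9, 9, 7],
                                     [3, 4, 9, 3]]]])
inputs = torch.tensor([[[[7, 3, 9, 8],
                         [3, 1, 0, 8],
                         [0, 5, 4, 8]]],
                       [[[3, 7, 5, 8],
                         [2, 5, 3, 8],
                         [3, 6, 7, 5]]]])

fill_masked_columns(reconstruct_output, inputs, mask)
print(reconstruct_output)
```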

Here's a small example I checked this with:

import numpy as np

mask = np.random.randint(0, 3, (2, 1, 4))
reconstruct_output = np.random.randint(0, 10, (2, 1, 3, 4))
inputs = np.random.randint(0, 10, (2, 1, 3, 4))

This gives, for instance:

reconstruct_output

array([[[[8, 9, 7, 2],
         [5, 4, 6, 1],
         [1, 4, 0, 3]]],


       [[[4, 3, 3, 4],
         [0, 9, 9, 7],
         [3, 4, 9, 3]]]])

inputs

array([[[[7, 3, 9, 8],
         [3, 1, 0, 8],
         [0, 5, 4, 8]]],


       [[[3, 7, 5, 8],
         [2, 5, 3, 8],
         [3, 6, 7, 5]]]])

And the mask:

mask

array([[[0, 1, 2, 1]],

       [[1, 0, 1, 0]]])

Using np.where to find the indices of the zeros in mask, we get:

m = mask == 0
i, _, l = np.where(m)

i
# array([0, 1, 1])

l
# array([0, 1, 3])

Hence we'll be replacing column 0 of the first 2D array and columns 1 and 3 of the second 2D array.
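To make concrete what the paired index arrays do, here is a sketch comparing the vectorized assignment with an explicit loop over the (i, l) pairs, on seeded random data of the same small shapes (the seed and variable names `vectorized` / `looped` are just for illustration). Each pair names one batch entry and one column, and the whole column is copied.

```python
import numpy as np

rng = np.random.default_rng(42)
mask = rng.integers(0, 3, size=(2, 1, 4))
reconstruct_output = rng.integers(0, 10, size=(2, 1, 3, 4))
inputs = rng.integers(0, 10, size=(2, 1, 3, 4))

i, _, l = np.where(mask == 0)

# Vectorized assignment from the answer, applied to a copy.
vectorized = reconstruct_output.copy()
vectorized[i, ..., l] = inputs[i, ..., l]

# Explicit loop: each (i[k], l[k]) pair selects one 2D array and one column
# index in it; all rows of that column are copied from inputs.
looped = reconstruct_output.copy()
for b, c in zip(i, l):
    looped[b, :, :, c] = inputs[b, :, :, c]

print(np.array_equal(vectorized, looped))
```

The loop is only for understanding; the vectorized form avoids Python-level iteration, which matters at the question's sizes.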

We can now use these arrays to index the corresponding axes of both arrays and do the replacement:

reconstruct_output[i, ..., l] = inputs[i, ..., l]

Getting:

reconstruct_output

array([[[[7, 9, 7, 2],
         [3, 4, 6, 1],
         [0, 4, 0, 3]]],


       [[[4, 7, 3, 8],
         [0, 5, 9, 8],
         [3, 6, 9, 5]]]])
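If updating reconstruct_output in place is not required, a different technique (not the one used in this answer) is the three-argument np.where with broadcasting: add a length-1 axis to the mask so it broadcasts along the row axis, then select from inputs where the mask is 0 and from reconstruct_output elsewhere. A sketch on the same example data, which should reproduce the array above as a new array:

```python
import numpy as np

mask = np.array([[[0, 1, 2, 1]],
                 [[1, 0, 1, 0]]])
reconstruct_output = np.array([[[[8, 9, 7, 2],
                                 [5, 4, 6, 1],
                                 [1, 4, 0, 3]]],
                               [[[4, 3, 3, 4],
                                 [0, 9, 9, 7],
                                 [3, 4, 9, 3]]]])
inputs = np.array([[[[7, 3, 9, 8],
                     [3, 1, 0, 8],
                     [0, 5, 4, 8]]],
                   [[[3, 7, 5, 8],
                     [2, 5, 3, 8],
                     [3, 6, 7, 5]]]])

# (2, 1, 4) -> (2, 1, 1, 4), which broadcasts against (2, 1, 3, 4)
# along the row axis.
cond = (mask == 0)[:, :, np.newaxis, :]
# Take inputs where the mask is 0, keep reconstruct_output elsewhere.
result = np.where(cond, inputs, reconstruct_output)
print(result)
```

`torch.where(condition, input, other)` also broadcasts, so `torch.where((mask == 0).unsqueeze(2), inputs, reconstruct_output)` should be the PyTorch counterpart. Unlike the indexing approach, it allocates a new array instead of modifying reconstruct_output.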

Upvotes: 2
