I have a mask with a size of torch.Size([20, 1, 199]) and two tensors, reconstruct_output and inputs, both with a size of torch.Size([20, 1, 161, 199]).
I want to set reconstruct_output to inputs wherever the mask is 0. I tried:
reconstruct_output[mask == 0] = inputs[mask == 0]
But I get an error:
IndexError: The shape of the mask [20, 1, 199] at index 2 does not match the shape of the indexed tensor [20, 1, 161, 199] at index 2
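The error comes from the rule that a boolean mask used as an index must match the shape of the leading dimensions of the tensor it indexes: here the mask's third dimension (199) is checked against the tensor's third dimension (161). A minimal sketch of one workaround, assuming each mask entry should apply to all 161 positions along the third axis, is to insert a singleton axis into the mask and expand it to the full shape. Random tensors stand in for the real data.

```python
import torch

# Random stand-ins with the shapes from the question
mask = torch.randint(0, 3, (20, 1, 199))
reconstruct_output = torch.randn(20, 1, 161, 199)
inputs = torch.randn(20, 1, 161, 199)
original = reconstruct_output.clone()

# (20, 1, 199) -> (20, 1, 1, 199), then expand (a view, no copy) to (20, 1, 161, 199)
full_mask = (mask == 0).unsqueeze(2).expand_as(inputs)

# The boolean index now has the same shape as the indexed tensor
reconstruct_output[full_mask] = inputs[full_mask]
```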
We can use advanced indexing here. To obtain the index arrays used to index both reconstruct_output and inputs, we need the indices along the axes of mask where mask == 0. For that we can use np.where, and then use the resulting indices to update reconstruct_output as:
m = mask == 0
i, _, l = np.where(m)
reconstruct_output[i, ..., l] = inputs[i, ..., l]
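Since the question works with PyTorch tensors, here is a minimal sketch of the same idea using torch.where in place of np.where (random tensors with the question's shapes stand in for the real data). Called with only a condition, torch.where returns one index tensor per dimension, like np.where, and PyTorch should apply the same advanced-indexing rules to i, ..., l.

```python
import torch

# Random stand-ins with the shapes from the question
mask = torch.randint(0, 3, (20, 1, 199))
reconstruct_output = torch.randn(20, 1, 161, 199)
inputs = torch.randn(20, 1, 161, 199)
before = reconstruct_output.clone()

m = mask == 0
# One index tensor per mask dimension; the middle one is always 0 and is discarded
i, _, l = torch.where(m)

# Index the first and last axes; `...` keeps the size-1 and size-161 axes whole
reconstruct_output[i, ..., l] = inputs[i, ..., l]
```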
Here's a small NumPy example I've checked this with:
import numpy as np
mask = np.random.randint(0, 3, (2, 1, 4))
reconstruct_output = np.random.randint(0, 10, (2, 1, 3, 4))
inputs = np.random.randint(0, 10, (2, 1, 3, 4))
Which gives, for instance:
reconstruct_output
array([[[[8, 9, 7, 2],
         [5, 4, 6, 1],
         [1, 4, 0, 3]]],
       [[[4, 3, 3, 4],
         [0, 9, 9, 7],
         [3, 4, 9, 3]]]])
inputs
array([[[[7, 3, 9, 8],
         [3, 1, 0, 8],
         [0, 5, 4, 8]]],
       [[[3, 7, 5, 8],
         [2, 5, 3, 8],
         [3, 6, 7, 5]]]])
And the mask:
mask
array([[[0, 1, 2, 1]],
       [[1, 0, 1, 0]]])
Using np.where to find the indices of the zeros in mask, we get:
m = mask == 0
i, _, l = np.where(m)
i
# array([0, 1, 1])
l
# array([0, 1, 3])
Hence we'll be replacing column 0 of the first 2D array and columns 1 and 3 of the second 2D array.
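To make the indexing concrete, this sketch rebuilds mask and inputs from the values printed above and inspects inputs[i, ..., l]. Because the two index arrays are separated by the ellipsis, NumPy places the indexed dimension first, so the selection has shape (3, 1, 3): one full column (all three rows) per zero in mask.

```python
import numpy as np

# The mask and inputs printed above, copied as fixed arrays
mask = np.array([[[0, 1, 2, 1]],
                 [[1, 0, 1, 0]]])
inputs = np.array([[[[7, 3, 9, 8],
                     [3, 1, 0, 8],
                     [0, 5, 4, 8]]],
                   [[[3, 7, 5, 8],
                     [2, 5, 3, 8],
                     [3, 6, 7, 5]]]])

i, _, l = np.where(mask == 0)
selected = inputs[i, ..., l]
print(selected.shape)  # (3, 1, 3)
# Each row is one selected column of `inputs`:
print(selected[:, 0])
# [[7 3 0]
#  [7 5 6]
#  [8 8 5]]
```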
We can now use these arrays to index the corresponding axes (the first and the last) and perform the replacement:
reconstruct_output[i, ..., l] = inputs[i, ..., l]
Which gives:
reconstruct_output
array([[[[7, 9, 7, 2],
         [3, 4, 6, 1],
         [0, 4, 0, 3]]],
       [[[4, 7, 3, 8],
         [0, 5, 9, 8],
         [3, 6, 9, 5]]]])
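As an alternative (a different technique from the index-based approach above), the same result can be obtained without computing indices: give mask a singleton third axis so it broadcasts against the 4D arrays, then select with the three-argument form of np.where. A minimal sketch using the example's printed values; note that it returns a new array rather than updating reconstruct_output in place. With PyTorch tensors, torch.where(mask.unsqueeze(2) == 0, inputs, reconstruct_output) broadcasts the same way.

```python
import numpy as np

# The example arrays printed above, copied as fixed arrays
mask = np.array([[[0, 1, 2, 1]],
                 [[1, 0, 1, 0]]])
reconstruct_output = np.array([[[[8, 9, 7, 2],
                                 [5, 4, 6, 1],
                                 [1, 4, 0, 3]]],
                               [[[4, 3, 3, 4],
                                 [0, 9, 9, 7],
                                 [3, 4, 9, 3]]]])
inputs = np.array([[[[7, 3, 9, 8],
                     [3, 1, 0, 8],
                     [0, 5, 4, 8]]],
                   [[[3, 7, 5, 8],
                     [2, 5, 3, 8],
                     [3, 6, 7, 5]]]])

# (2, 1, 4) -> (2, 1, 1, 4), which broadcasts against (2, 1, 3, 4)
cond = mask[:, :, None, :] == 0
# Take `inputs` where the mask is 0, otherwise keep `reconstruct_output`
result = np.where(cond, inputs, reconstruct_output)
```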