Reputation: 21280
Hello, I have the following data:
import numpy as np
ids = np.concatenate([1.0 * np.ones(shape=(4, 9)),
                      2.0 * np.ones(shape=(4, 3))], axis=1)
logits = np.random.normal(size=(4, 9 + 3, 256))
Now I want to select only the entries of logits where ids is 1.0, so that the result is an array of shape (4, 9, 256).
I tried logits[ids == 1.0, :]
but I get an array of shape (36, 256).
How can I do this slicing without merging the first two dimensions?
The dimensions here are only examples; I am looking for a generic solution.
Upvotes: 4
Views: 780
Reputation: 86328
Your question appears to assume that each row has the same number of matching entries (positions where ids == 1); in that case you can solve your problem generally like this:
mask = (ids == 1)
num_per_row = mask.sum(1)
# same number of entries per row is required
assert np.all(num_per_row == num_per_row[0])
result = logits[mask].reshape(logits.shape[0], num_per_row[0], logits.shape[2])
print(result.shape)
# (4, 9, 256)
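For reference, here is the same idea as a minimal self-contained sketch. The helper name select_by_mask is hypothetical (it is not a NumPy function), and the sketch assumes a boolean mask of shape (rows, cols) that selects the same number of entries in every row. It relies on the fact that indexing with a 2-D boolean mask merges the first two axes in row-major order, so a reshape can restore the row axis.

```python
import numpy as np

def select_by_mask(values, mask):
    """Hypothetical helper: values has shape (rows, cols, ...), mask is boolean (rows, cols)."""
    counts = mask.sum(axis=1)
    # The reshape below only makes sense if every row selects equally many entries.
    if not np.all(counts == counts[0]):
        raise ValueError("every row must select the same number of entries")
    # values[mask] has shape (rows * count, ...), ordered row by row,
    # so reshaping splits the merged axis back into (rows, count).
    return values[mask].reshape(values.shape[0], int(counts[0]), *values.shape[2:])

ids = np.concatenate([1.0 * np.ones((4, 9)), 2.0 * np.ones((4, 3))], axis=1)
logits = np.random.normal(size=(4, 9 + 3, 256))

result = select_by_mask(logits, ids == 1.0)
print(result.shape)
```

If the rows select different numbers of entries, the result cannot be a regular array, and the helper raises ValueError instead of silently producing a wrong shape.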
Upvotes: 1