Night Walker

Reputation: 21280

Select elements of numpy array via mask and preserving original dimensions

Hello, I have the following data:

import numpy as np

ids = np.concatenate([1.0 * np.ones(shape=(4, 9)),
                      2.0 * np.ones(shape=(4, 3))], axis=1)

logits = np.random.normal(size=(4, 9 + 3, 256))

Now I want to select only the entries of logits whose ids equal 1.0, and get an array of shape (4, 9, 256).

I tried logits[ids == 1.0, :], but that gives an array of shape (36, 256). How can I do the selection without merging the first two dimensions?

The dimensions above are only examples; I am looking for a generic solution.
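The following minimal check shows why the attempt above flattens. The setup is copied from the question. When a 3-D array is indexed with a 2-D boolean mask, the mask consumes the first two axes. NumPy replaces them with a single axis whose length is the number of True entries, here 4 * 9 = 36.

```python
import numpy as np

# Same setup as in the question
ids = np.concatenate([1.0 * np.ones(shape=(4, 9)),
                      2.0 * np.ones(shape=(4, 3))], axis=1)
logits = np.random.normal(size=(4, 9 + 3, 256))

mask = ids == 1.0  # boolean array, shape (4, 12)

# The 2-D boolean index covers axes 0 and 1 of logits and replaces them
# with one axis of length mask.sum(); the last axis (256) is kept as-is.
selected = logits[mask]
print(mask.sum())      # 36
print(selected.shape)  # (36, 256)
```

The selected entries come out in row-major order: all of row 0 first, then row 1, and so on. This ordering is what makes it possible to regroup them per row afterwards.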

Upvotes: 4

Views: 780

Answers (1)

jakevdp

Reputation: 86328

Your question appears to assume that each row of the mask has the same number of True entries (9 per row in your example). In that case, you can solve the problem generally like this:

mask = (ids == 1)
num_per_row = mask.sum(1)

# same number of entries per row is required
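assert np.all(num_per_row == num_per_row[0])

# logits[mask] returns the selected entries in row-major order, so they can be
# regrouped into (rows, entries per row, last dimension)
result = logits[mask].reshape(logits.shape[0], num_per_row[0], logits.shape[2])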

print(result.shape)
# (4, 9, 256)
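The same idea can be wrapped in a reusable function. The sketch below uses a hypothetical helper named `select_per_row`. It assumes, as the answer does, that every row of the mask selects the same number of entries. Rather than reshaping `logits[mask]`, it gets the column indices of the True entries with `np.nonzero` and gathers them with `np.take_along_axis`. This gives the same result and also works for any number of trailing axes, not just one. An equivalent reshape-based generalization would be `logits[mask].reshape((logits.shape[0], num_per_row[0]) + logits.shape[2:])`.

```python
import numpy as np


def select_per_row(values, mask):
    """Hypothetical helper: pick entries of `values` along axis 1 where `mask`
    is True, keeping axis 0 and all trailing axes separate.

    Assumes mask.shape == values.shape[:2] and that every row of `mask`
    has the same number of True entries.
    """
    mask = np.asarray(mask, dtype=bool)
    if mask.shape != values.shape[:2]:
        raise ValueError("mask must match the first two dimensions of values")
    counts = mask.sum(axis=1)
    if not np.all(counts == counts[0]):
        raise ValueError("every row must select the same number of entries")
    # np.nonzero lists True positions in row-major order, so the column
    # indices can be regrouped into one row of indices per row of `mask`.
    cols = np.nonzero(mask)[1].reshape(mask.shape[0], counts[0])
    # Append singleton axes so the indices broadcast over trailing dimensions.
    idx = cols.reshape(cols.shape + (1,) * (values.ndim - 2))
    return np.take_along_axis(values, idx, axis=1)


# Usage with the data from the question
ids = np.concatenate([1.0 * np.ones(shape=(4, 9)),
                      2.0 * np.ones(shape=(4, 3))], axis=1)
logits = np.random.normal(size=(4, 9 + 3, 256))

result = select_per_row(logits, ids == 1.0)
print(result.shape)  # (4, 9, 256)
```

Because the column indices are taken per row, the selected positions may differ from row to row, as long as each row selects the same number of them.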

Upvotes: 1
