Reputation: 137
I'm currently stuck writing a NumPy script whose main goal is efficiency (so vectorization is mandatory).
Let's assume a 3-D array:
arr = [[[0, 0, 0, 0],
[0, 0, 3, 4],
[0, 0, 3, 0],
[0, 2, 3, 0]],
[[0, 0, 3, 0],
[0, 0, 0, 0],
[1, 0, 3, 0],
[0, 0, 0, 0]],
[[0, 2, 3, 4],
[0, 0, 0, 0],
[0, 0, 3, 4],
[0, 0, 3, 0]],
[[0, 0, 3, 4],
[0, 0, 3, 4],
[0, 0, 0, 0],
[0, 0, 0, 0]]]
My goal is to zero out every column that has more than one non-zero value. So, for the array above, the result should be something like:
filtered = [[[0, 0, 0, 0],
[0, 0, 0, 4],
[0, 0, 0, 0],
[0, 2, 0, 0]],
[[0, 0, 0, 0],
[0, 0, 0, 0],
[1, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 2, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]]]
I've managed to work around this with a combination of np.count_nonzero, np.repeat and reshape:
# keep only columns with at most one non-zero entry
indices = np.repeat(np.count_nonzero(a=arr, axis=1) <= 1, repeats=4, axis=0).reshape(4, 4, 4)
result = indices * arr
This produces the correct result, but it feels like it misses the point: there is a lot of cryptic shape manipulation just to mask the array properly. Furthermore, I'd like this function to be flexible enough to work along other axes too (e.g. for rows), giving:
rows_fil = [[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 3, 0],
[0, 0, 0, 0]],
[[0, 0, 3, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 3, 0]],
[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]]]
Is there a "numpy" way to write such a flexible function?
Upvotes: 2
Views: 131
Reputation: 221584
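Here's a solution that covers a generic axis param:
import numpy as np

def mask_nnzcount(a, axis):
    # a is the input array (a nested list is converted to an ndarray)
    a = np.asarray(a)
    # True wherever the slice along `axis` holds more than one non-zero
    mask = (a != 0).sum(axis=axis, keepdims=True) > 1
    return np.where(mask, 0, a)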
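The trick really is keepdims=True, which allows us to have a generic solution: the reduced axis is kept with length 1, so the mask broadcasts back against a whichever axis was chosen.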
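To see why keepdims matters, here is a minimal sketch on a small made-up array (the input `a` below is hypothetical and not the array from the question). It only compares the shapes with and without keepdims:

```python
import numpy as np

# Hypothetical input, used only to illustrate shapes
a = np.arange(24).reshape(2, 3, 4)
nz = a != 0

# Without keepdims the reduced axis disappears
print(nz.sum(axis=1).shape)            # (2, 4)

# With keepdims the reduced axis stays, with length 1
counts = nz.sum(axis=1, keepdims=True)
print(counts.shape)                    # (2, 1, 4)

# A length-1 axis broadcasts against the original array
print(np.broadcast(counts, a).shape)   # (2, 3, 4)
```

Because the length-1 axis stays where the reduced axis was, no repeat or reshape is needed to line the counts up with the input.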
With a 3D array, for your column-fill, that's axis=1, and for row-fill it's axis=2.
For a generic ndarray, you might want to use axis=-2 for column-fill and axis=-1 for row-fill.
Alternatively, we could use element-wise multiplication at the last step instead and get the output with a*(~mask). Or build an inverted mask, say inv_mask = (a!=0).sum(axis=axis, keepdims=True)<=1, and then do a*inv_mask.
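As a quick check, here is a sketch that runs the function on the array from the question. The function is repeated so the snippet runs on its own, and the np.asarray call is an addition so that nested lists are accepted too:

```python
import numpy as np

def mask_nnzcount(a, axis):
    a = np.asarray(a)  # added here so nested lists work as well
    mask = (a != 0).sum(axis=axis, keepdims=True) > 1
    return np.where(mask, 0, a)

arr = np.array([[[0, 0, 0, 0], [0, 0, 3, 4], [0, 0, 3, 0], [0, 2, 3, 0]],
                [[0, 0, 3, 0], [0, 0, 0, 0], [1, 0, 3, 0], [0, 0, 0, 0]],
                [[0, 2, 3, 4], [0, 0, 0, 0], [0, 0, 3, 4], [0, 0, 3, 0]],
                [[0, 0, 3, 4], [0, 0, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]]])

filtered = mask_nnzcount(arr, axis=1)   # column-fill
rows_fil = mask_nnzcount(arr, axis=2)   # row-fill

print(filtered[0])
# [[0 0 0 0]
#  [0 0 0 4]
#  [0 0 0 0]
#  [0 2 0 0]]

# The multiplication variants give the same array
mask = (arr != 0).sum(axis=1, keepdims=True) > 1
inv_mask = (arr != 0).sum(axis=1, keepdims=True) <= 1
print(np.array_equal(arr * ~mask, filtered))     # True
print(np.array_equal(arr * inv_mask, filtered))  # True
```

Both outputs should match the filtered and rows_fil arrays listed in the question.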
Upvotes: 1