Łukasz Szymankiewicz

Reputation: 137

Conditional replacement of column in numpy array

I'm currently stuck writing a NumPy script whose main goal is efficiency (so vectorization is mandatory).

Let's assume a 3-D array:

arr = [[[0, 0, 0, 0],
        [0, 0, 3, 4],
        [0, 0, 3, 0],
        [0, 2, 3, 0]],

       [[0, 0, 3, 0],
        [0, 0, 0, 0],
        [1, 0, 3, 0],
        [0, 0, 0, 0]],

       [[0, 2, 3, 4],
        [0, 0, 0, 0],
        [0, 0, 3, 4],
        [0, 0, 3, 0]],

       [[0, 0, 3, 4],
        [0, 0, 3, 4],
        [0, 0, 0, 0],
        [0, 0, 0, 0]]]

My goal is to zero out every column that has more than one nonzero value. So, given the array above, the result should be something like:

filtered = [[[0, 0, 0, 0],
             [0, 0, 0, 4],
             [0, 0, 0, 0],
             [0, 2, 0, 0]],

            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [1, 0, 0, 0],
             [0, 0, 0, 0]],

            [[0, 2, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]],

            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]]]

I've managed to work around this with a combination of np.count_nonzero, np.repeat and reshape:

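import numpy as np
# True where a column holds at most one nonzero value, repeated to arr's shape
indices = np.repeat(np.count_nonzero(a=arr, axis=1) <= 1, repeats=4, axis=0).reshape(4, 4, 4)
result = indices * arr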

This produces the right result but seems to miss the point (there is a lot of cryptic shape manipulation just to mask the array properly). Furthermore, I'd like this function to be flexible enough to work along other axes too (e.g. for rows), resulting in:

rows_fil = [[[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 3, 0],
             [0, 0, 0, 0]],

            [[0, 0, 3, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]],

            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 3, 0]],

            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]]]

Is there an idiomatic NumPy way to write such a flexible function?

Upvotes: 2

Views: 131

Answers (1)

Divakar

Reputation: 221584

Here's a solution that takes a generic axis parameter:

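import numpy as np

def mask_nnzcount(a, axis):
    # a is the input array; axis is the axis along which nonzeros are counted
    # mask is True wherever the slice along `axis` holds more than one nonzero
    mask = (a != 0).sum(axis=axis, keepdims=True) > 1
    return np.where(mask, 0, a)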

The trick is keepdims=True: the reduced axis is kept with length 1, so the mask broadcasts against a whichever axis was chosen. That is what makes the solution generic.
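
To make the role of keepdims concrete, here is a small sketch comparing the shape of the nonzero count with and without it. The array below is an arbitrary illustration, not the one from the question.

```python
import numpy as np

# Arbitrary example array of shape (2, 3, 4); the values are illustrative only.
a = np.arange(24).reshape(2, 3, 4) % 3

counts_drop = (a != 0).sum(axis=1)                 # shape (2, 4): axis 1 is removed
counts_keep = (a != 0).sum(axis=1, keepdims=True)  # shape (2, 1, 4): axis 1 kept with length 1

# (2, 1, 4) broadcasts against (2, 3, 4); the (2, 4) result would not line up.
mask = counts_keep > 1
filtered = np.where(mask, 0, a)
```

Because the reduced axis survives as a length-1 dimension, the same two lines work for any axis without a manual reshape or repeat.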

With a 3D array, the column filter corresponds to axis=1 and the row filter to axis=2.

For a generic ndarray, you might want to use axis=-2 for the column filter and axis=-1 for the row filter.
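
As a quick check, the sketch below applies the function to the array from the question, assuming arr is held as a NumPy array rather than a nested list. The function is repeated only so that the snippet runs on its own.

```python
import numpy as np

def mask_nnzcount(a, axis):
    # Same function as in the answer, repeated so this snippet is self-contained.
    mask = (a != 0).sum(axis=axis, keepdims=True) > 1
    return np.where(mask, 0, a)

# The array from the question, converted to an ndarray.
arr = np.array([[[0, 0, 0, 0], [0, 0, 3, 4], [0, 0, 3, 0], [0, 2, 3, 0]],
                [[0, 0, 3, 0], [0, 0, 0, 0], [1, 0, 3, 0], [0, 0, 0, 0]],
                [[0, 2, 3, 4], [0, 0, 0, 0], [0, 0, 3, 4], [0, 0, 3, 0]],
                [[0, 0, 3, 4], [0, 0, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]]])

cols = mask_nnzcount(arr, axis=-2)  # count down each column of every 4x4 block
rows = mask_nnzcount(arr, axis=-1)  # count along each row of every 4x4 block
```

Here cols should correspond to filtered and rows to rows_fil from the question.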

Alternatively, we could use element-wise multiplication in the last step and get the output with a*(~mask). Or build the inverted mask directly, i.e. inv_mask = (a!=0).sum(axis=axis, keepdims=True)<=1, and then compute a*inv_mask.
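
A minimal sketch of the three variants side by side, on a small made-up integer array (the array and the axis value are assumptions for illustration):

```python
import numpy as np

# Small illustrative array: column 1 has two nonzeros, column 2 has one.
a = np.array([[0, 5, 0],
              [0, 7, 2]])
axis = 0

mask = (a != 0).sum(axis=axis, keepdims=True) > 1
inv_mask = (a != 0).sum(axis=axis, keepdims=True) <= 1

out_where = np.where(mask, 0, a)  # select 0 where the mask is True
out_not = a * (~mask)             # multiply by the negated mask
out_inv = a * inv_mask            # multiply by the inverted mask directly
```

For integer input like this, all three give the same array. One caveat with the multiplication forms: on floating-point input containing nan or inf, multiplying by False yields nan, whereas np.where writes a plain 0.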

Upvotes: 1
