Merlin

Reputation: 25629

How to apply mask from array to another matrix in numpy

How do I apply a mask in numpy to get this output?

import numpy as np
ar2 = np.arange(1,26)[::-1].reshape([5,5]).T
ar3 = np.array([1,1,-1,-1,1])
print ar2, '\n\n',  ar3

[[25 20 15 10  5]
 [24 19 14  9  4]
 [23 18 13  8  3]
 [22 17 12  7  2]
 [21 16 11  6  1]] 

[ 1  1 -1 -1  1]

--apply where ar3 = 1: ar2/ar2[:,0][:, np.newaxis]

--apply where ar3 = -1: ar2/ar2[:,4][:, np.newaxis]

The result I am after is:

[[ 1  0  0  0  0]
 [ 1  0  0  0  0]
 [ 7  6  4  2  1]
 [11  8  6  3  1]
 [ 1  0  0  0  0]]

I have tried np.where()

Upvotes: 1

Views: 313

Answers (2)

MSeifert

Reputation: 152587

I don't see why np.where shouldn't work here:

>>> np.where((ar3==1)[:, None], 
...          ar2 // ar2[:, [0]],  # where condition is True, divide by first column
...          ar2 // ar2[:, [4]])  # where condition is False, divide by last column
array([[ 1,  0,  0,  0,  0],
       [ 1,  0,  0,  0,  0],
       [ 7,  6,  4,  2,  1],
       [11,  8,  6,  3,  1],
       [ 1,  0,  0,  0,  0]])

I'm using Python 3, which is why I used // (floor division) instead of regular division (/); otherwise the result would contain floats.
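
As a quick illustration of that difference, here is a minimal sketch. The array `x` is made up for the example and is not part of the question:

```python
import numpy as np

x = np.array([25, 20, 15])  # illustrative integer array (an assumption, not from the question)

true_div = x / x[0]     # true division: the result is a float array in Python 3
floor_div = x // x[0]   # floor division: the result keeps an integer dtype

print(true_div.dtype)   # float64
print(floor_div)        # [1 0 0]
```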

np.where evaluates both branches eagerly: it computes ar2 // ar2[:, [0]] and ar2 // ar2[:, [4]] for every row, so three arrays the size of ar2 are held in memory at once (the result and the two temporaries). If you want something more memory-efficient, apply the mask before doing the operation:

>>> res = np.empty_like(ar2)
>>> mask = ar3 == 1
>>> res[mask] = ar2[mask] // ar2[mask][:, [0]]
>>> res[~mask] = ar2[~mask] // ar2[~mask][:, [4]]
>>> res
array([[ 1,  0,  0,  0,  0],
       [ 1,  0,  0,  0,  0],
       [ 7,  6,  4,  2,  1],
       [11,  8,  6,  3,  1],
       [ 1,  0,  0,  0,  0]])

This computes only the necessary values, which uses less memory (and is probably faster too).
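
To check that the two approaches agree, something like the following self-contained script should work. It rebuilds `ar2` and `ar3` from the question; the names `eager` and `res` are just labels chosen for this sketch:

```python
import numpy as np

# Inputs as defined in the question.
ar2 = np.arange(1, 26)[::-1].reshape([5, 5]).T
ar3 = np.array([1, 1, -1, -1, 1])
mask = ar3 == 1

# Approach 1: np.where computes both quotients for every row, then selects per row.
eager = np.where(mask[:, None], ar2 // ar2[:, [0]], ar2 // ar2[:, [4]])

# Approach 2: fill a preallocated result, dividing only the rows each branch needs.
res = np.empty_like(ar2)
res[mask] = ar2[mask] // ar2[mask][:, [0]]
res[~mask] = ar2[~mask] // ar2[~mask][:, [4]]

print(np.array_equal(eager, res))  # True
```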

Upvotes: 3

cs95

Reputation: 402263

Not the most elegant, but here's what I could think of.

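m = ar3 == -1                    # rows that should be divided by the last column
a = ar2 // ar2[:, [0]]           # divide every row by its first column
a[m] = (ar2 // ar2[:, [4]])[m]   # overwrite the masked rows with the last-column quotient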

print(a)
array([[ 1,  0,  0,  0,  0],
       [ 1,  0,  0,  0,  0],
       [ 7,  6,  4,  2,  1],
       [11,  8,  6,  3,  1],
       [ 1,  0,  0,  0,  0]], dtype=int32)
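
A possible variant, offered only as a sketch: pick the divisor for each row first, then divide once. This swaps in a different technique from the code above (selecting the divisor rather than the quotient), and the names `divisor` and `result` are chosen for this example:

```python
import numpy as np

# Inputs as defined in the question.
ar2 = np.arange(1, 26)[::-1].reshape([5, 5]).T
ar3 = np.array([1, 1, -1, -1, 1])

# Per-row divisor: the first column where ar3 == 1, the last column otherwise.
divisor = np.where(ar3 == 1, ar2[:, 0], ar2[:, -1])

# A single floor division, broadcasting the divisor across each row.
result = ar2 // divisor[:, None]

print(result)
```

Only one full-size quotient is computed this way; the np.where call works on a length-5 vector instead of the whole matrix.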

Upvotes: 2
