I have two 2D Numpy arrays, let's say:
A1 = [[5, 3],
      [4, 6]]
and
A2 = [[7, 9],
      [5, 0]]
I want to be able to select an index of A1, let's say [1][0], which gives me a selected value of 4.
Now, I want to subtract A1 element-wise from A2 wherever both corresponding elements are nonzero. However, where the selected value is less than the A1 element, I want to subtract the selected value instead of the A1 element.
In this case, my final result would be:
A3 = [[3, 6],
      [1, 0]]
This is because 4 is less than 5, so I subtract 4 from A2[0][0]. 4 is not less than 3, so I subtract 3 from A2[0][1]. 4 is equal to 4, so I subtract 4 from A2[1][0]. The final value in A2 is zero, so I leave it alone.
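The rule described above can be written as a plain double loop, which is slow but handy as a reference to check a vectorized version against. This is a sketch; `capped_subtract_loop` is a hypothetical helper name, and it assumes A1 and A2 have the same shape.

```python
import numpy as np

def capped_subtract_loop(A1, A2, row, col):
    """Plain-loop version of the rule (hypothetical helper name).

    Where both A1 and A2 are nonzero, subtract min(A1[i, j], val) from
    A2[i, j], with val = A1[row, col]. Other positions keep their A2 value.
    """
    A1 = np.asarray(A1)
    A2 = np.asarray(A2)
    val = A1[row, col]
    A3 = A2.copy()
    n_rows, n_cols = A2.shape
    for i in range(n_rows):
        for j in range(n_cols):
            if A1[i, j] != 0 and A2[i, j] != 0:
                # Use the selected value when it is smaller than the A1 element
                A3[i, j] = A2[i, j] - min(A1[i, j], val)
    return A3

A1 = [[5, 3], [4, 6]]
A2 = [[7, 9], [5, 0]]
print(capped_subtract_loop(A1, A2, 1, 0))
# [[3 6]
#  [1 0]]
```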
Sorry, I don't have a code attempt because I simply don't know how to do this.
One way to do this is with numpy.where, which lets us select between two arrays depending on a condition. We first build a mask array that is True where both A1 & A2 are nonzero, and False otherwise. I'll call your selected value val. Your condition of using val wherever the A1 value is greater than val can be handled by taking the element-wise minimum of A1 and val. Here's the procedure, step by step, so we can see what's going on.
import numpy as np
A1 = np.array([[5, 3], [4, 6]])
A2 = np.array([[7, 9], [5, 0]])
print(A1, '\n')
print(A2, '\n')
# Selected indices of A1
row, col = 1, 0
val = A1[row, col]
print(val)
# Find where both A1 & A2 are nonzero
mask = (A1 != 0) & (A2 != 0)
print(mask, '\n')
# Replace values in A1 that are greater than val
A1a = np.minimum(A1, val)
print(A1a, '\n')
# Only do the subtraction where both A1 & A2 are nonzero,
# otherwise copy the A2 value
A3 = np.where(mask, A2 - A1a, A2)
print(A3, '\n')
Output:
[[5 3]
 [4 6]]

[[7 9]
 [5 0]]

4
[[ True  True]
 [ True False]]

[[4 3]
 [4 4]]

[[3 6]
 [1 0]]
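To see why the mask step matters, compare the result with and without it on the same example arrays. Without the mask, the bottom-right entry, where A2 is 0, would become 0 - min(6, 4) = -4.

```python
import numpy as np

A1 = np.array([[5, 3], [4, 6]])
A2 = np.array([[7, 9], [5, 0]])
val = A1[1, 0]

# Without the mask: every element of A2 has the capped A1 value subtracted,
# including the position where A2 is zero
no_mask = A2 - np.minimum(A1, val)
print(no_mask)
# [[ 3  6]
#  [ 1 -4]]

# With the mask: positions where A1 or A2 is zero keep their A2 value
mask = (A1 != 0) & (A2 != 0)
with_mask = np.where(mask, A2 - np.minimum(A1, val), A2)
print(with_mask)
# [[3 6]
#  [1 0]]
```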
And in one line:
A3 = np.where((A1 != 0) & (A2 != 0), A2 - np.minimum(A1, val), A2)
A closely related version is:
A3 = A2 - np.where((A1 != 0) & (A2 != 0), np.minimum(A1, val), 0)
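Both one-line forms can be wrapped in a small function that takes the selected index. The second form works because subtracting 0 where the mask is False leaves A2 unchanged. This sketch assumes A1 and A2 have the same shape; `capped_subtract` and `capped_subtract_alt` are hypothetical names.

```python
import numpy as np

def capped_subtract(A1, A2, row, col):
    """np.where form: A2 - min(A1, val) where both are nonzero, else A2."""
    val = A1[row, col]
    mask = (A1 != 0) & (A2 != 0)
    return np.where(mask, A2 - np.minimum(A1, val), A2)

def capped_subtract_alt(A1, A2, row, col):
    """Related form: subtract min(A1, val) where masked, and 0 elsewhere."""
    val = A1[row, col]
    mask = (A1 != 0) & (A2 != 0)
    return A2 - np.where(mask, np.minimum(A1, val), 0)

A1 = np.array([[5, 3], [4, 6]])
A2 = np.array([[7, 9], [5, 0]])

# Try every possible selected index; both forms should agree
for row in range(A1.shape[0]):
    for col in range(A1.shape[1]):
        r1 = capped_subtract(A1, A2, row, col)
        r2 = capped_subtract_alt(A1, A2, row, col)
        assert np.array_equal(r1, r2)

print(capped_subtract(A1, A2, 1, 0))
# [[3 6]
#  [1 0]]
```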
Here's yet another version, courtesy of Andras Deak, which is similar to the np.where approach but should be faster when mask has a large number of False entries:
A3 = A2.copy()
A3[mask] -= np.minimum(A1, val)[mask]
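One way to gain some confidence that the boolean-indexing version matches the np.where version is to compare them on random arrays containing some zeros. This is a sketch: the array size, value range, and seed are arbitrary choices, and `capped_subtract_where` / `capped_subtract_masked` are hypothetical names. The `A2.copy()` matters because `-=` on a masked selection writes into the array it is applied to.

```python
import numpy as np

def capped_subtract_where(A1, A2, val):
    mask = (A1 != 0) & (A2 != 0)
    return np.where(mask, A2 - np.minimum(A1, val), A2)

def capped_subtract_masked(A1, A2, val):
    mask = (A1 != 0) & (A2 != 0)
    A3 = A2.copy()  # work on a copy so A2 itself is not modified
    # Boolean indexing picks the masked elements in the same order on both sides
    A3[mask] -= np.minimum(A1, val)[mask]
    return A3

rng = np.random.default_rng(0)
for _ in range(100):
    # Values 0..9, so roughly 10% of entries are zero
    A1 = rng.integers(0, 10, size=(5, 7))
    A2 = rng.integers(0, 10, size=(5, 7))
    row, col = rng.integers(0, 5), rng.integers(0, 7)
    val = A1[row, col]
    A2_before = A2.copy()
    assert np.array_equal(capped_subtract_where(A1, A2, val),
                          capped_subtract_masked(A1, A2, val))
    assert np.array_equal(A2, A2_before)  # the original A2 is untouched
print("all checks passed")
```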