Travis Black

Reputation: 715

How to subtract up to a certain value where both arrays are nonzero?

I have two 2D Numpy arrays, let's say:

A1 = [[5,3],
      [4,6]]

and

A2 = [[7,9],
      [5,0]]

I want to be able to select an index of A1, let's say, [1][0], which gives me a selected value of 4.

Now, I want to subtract A1 from A2 element-wise wherever both corresponding elements are nonzero. However, where the selected value is less than the A1 element, I want to subtract the selected value instead of the A1 element.

In this case, that would mean my final result is:

A3 = [[3,6],
      [1,0]]

This is because 4 is less than 5, so I subtract 4 from A2[0][0]. 4 is not less than 3, so I subtract 3 from A2[0][1]. 4 is equal to 4, so I subtract 4 from A2[1][0]. The final value in A2 is zero, so I leave it alone.

Sorry, I don't have a code attempt because I simply don't know how to do this.

Upvotes: 2

Views: 1804

Answers (1)

PM 2Ring

Reputation: 55469

One way to do this is with numpy.where, which selects between two arrays element by element depending on a condition. We first build a mask array that is True where both A1 and A2 are nonzero, and False otherwise. I'll call your selected value val. Your rule of subtracting val wherever the A1 value is greater than val amounts to subtracting the element-wise minimum of A1 and val, which numpy.minimum computes. Here's the procedure, step by step, so we can see what's going on.

import numpy as np

A1 = np.array([[5, 3], [4, 6]])
A2 = np.array([[7, 9], [5, 0]])
print(A1, '\n')
print(A2, '\n')

# Selected indices of A1
row, col = 1, 0
val = A1[row, col]
print(val)

# Find where both A1 & A2 are nonzero
mask = (A1 != 0) & (A2 != 0)
print(mask, '\n')

# Cap A1 at val: any A1 value greater than val is replaced by val
A1a = np.minimum(A1, val)
print(A1a, '\n')

# Only do the subtraction where both A1 & A2 are nonzero,
# otherwise copy the A2 value
A3 = np.where(mask, A2 - A1a, A2)
print(A3, '\n')

output

[[5 3]
 [4 6]] 

[[7 9]
 [5 0]] 

4
[[ True  True]
 [ True False]] 

[[4 3]
 [4 4]] 

[[3 6]
 [1 0]] 

And in one line:

A3 = np.where((A1 != 0) & (A2 != 0), A2 - np.minimum(A1, val), A2)
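If you need this for more than one selected index, the one-liner can be wrapped in a small function. This is only a sketch: the name subtract_capped and its parameters are my own invention, not part of NumPy, and it assumes A1 and A2 have the same shape.

```python
import numpy as np

def subtract_capped(a1, a2, index):
    """Hypothetical helper (not a NumPy function).

    Subtracts min(a1, a1[index]) from a2 wherever both arrays are
    nonzero; elsewhere the a2 value is copied unchanged.
    """
    val = a1[index]                      # selected value
    mask = (a1 != 0) & (a2 != 0)         # True where both are nonzero
    return np.where(mask, a2 - np.minimum(a1, val), a2)

A1 = np.array([[5, 3], [4, 6]])
A2 = np.array([[7, 9], [5, 0]])

print(subtract_capped(A1, A2, (1, 0)))   # selected value is 4
print(subtract_capped(A1, A2, (0, 1)))   # selected value is 3
```

The index is passed as a tuple, so a1[index] is ordinary NumPy indexing. Neither input array is modified.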

A closely related version is:

A3 = A2 - np.where((A1 != 0) & (A2 != 0), np.minimum(A1, val), 0)

Here's yet another version, courtesy of Andras Deak, which is similar to np.where, but should be faster when mask has a large number of False entries:

A3 = A2.copy()
A3[mask] -= np.minimum(A1, val)[mask]
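If you want to convince yourself that the three versions agree, you can compare them against a plain Python loop. The loop below is just my reading of the rule in the question, written out explicitly; the names v1, v2, v3 and ref are arbitrary.

```python
import numpy as np

A1 = np.array([[5, 3], [4, 6]])
A2 = np.array([[7, 9], [5, 0]])
val = A1[1, 0]
mask = (A1 != 0) & (A2 != 0)

# Version 1: choose between the difference and the untouched A2 value
v1 = np.where(mask, A2 - np.minimum(A1, val), A2)

# Version 2: subtract zero wherever the mask is False
v2 = A2 - np.where(mask, np.minimum(A1, val), 0)

# Version 3: update a copy in place using boolean indexing
v3 = A2.copy()
v3[mask] -= np.minimum(A1, val)[mask]

# Reference: explicit loop over every element
ref = A2.copy()
for i in range(A1.shape[0]):
    for j in range(A1.shape[1]):
        if A1[i, j] != 0 and A2[i, j] != 0:
            ref[i, j] -= min(A1[i, j], val)

print(np.array_equal(v1, ref), np.array_equal(v2, ref), np.array_equal(v3, ref))
```

For the arrays in the question, all three comparisons print True.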

Upvotes: 2
