keren42

Reputation: 41

Change some NumPy array columns according to a logical condition

I have a 2D NumPy array, and I want to perform the following operation:

For each column of the array whose values form a non-decreasing sequence, replace that column with its column of differences (each entry is the difference between two consecutive entries of the original column, i.e. row i+1 minus row i).

Every other column remains the same, except that its first row is removed so that it matches the length of the difference columns.

For example, in the matrix:

[[1, 1, 1, 2, 3, 4],
 [1, 3, 4, 3, 4, 5],
 [1, 7, 3, 4, 2, 7]]

The differences matrix is:

[[0, 2,  3, 1,  1, 1],
 [0, 4, -1, 1, -2, 2]]

Thus the third and fifth columns, which contain a decrease, keep their original values (with the first row dropped), while the other columns are replaced by their differences, resulting in:

[[0, 2, 4, 1, 4, 1],
 [0, 4, 3, 1, 2, 2]]
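To pin down the intended semantics, here is a plain column-by-column loop that reproduces the expected result for the example matrix. It is only a reference sketch for checking a vectorized version against; the helper name `replace_nondecreasing_columns` is hypothetical.

```python
import numpy as np

def replace_nondecreasing_columns(a):
    """Loop-based reference for the operation described above (hypothetical name)."""
    a = np.asarray(a)
    out = np.empty((a.shape[0] - 1, a.shape[1]), dtype=a.dtype)
    for j in range(a.shape[1]):
        col = a[:, j]
        if np.all(col[1:] >= col[:-1]):
            # Non-decreasing column: store the consecutive differences.
            out[:, j] = col[1:] - col[:-1]
        else:
            # Column decreases somewhere: keep the original values, first row dropped.
            out[:, j] = col[1:]
    return out

a = np.array([[1, 1, 1, 2, 3, 4],
              [1, 3, 4, 3, 4, 5],
              [1, 7, 3, 4, 2, 7]])
print(replace_nondecreasing_columns(a))
# [[0 2 4 1 4 1]
#  [0 4 3 1 2 2]]
```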

I tried something like this:

tempX = np.diff(X, axis = 0).transpose()
return np.where(tempX >= 0, tempX, X[1:].transpose())

But np.where evaluates the condition element-wise, not once per column (or per row, after the transpose).
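A small sketch of the problem on the example matrix: with an element-wise condition, a column that has one negative difference ends up as a mix of differences and original values. Reducing the condition over the rows first gives one boolean per column, which np.where then broadcasts across all rows. The transposes from the attempt are dropped here, on the assumption that they were only there to work around the row/column orientation.

```python
import numpy as np

X = np.array([[1, 1, 1, 2, 3, 4],
              [1, 3, 4, 3, 4, 5],
              [1, 7, 3, 4, 2, 7]])
d = np.diff(X, axis=0)

# Element-wise condition: each entry is chosen independently.
elementwise = np.where(d >= 0, d, X[1:])
print(elementwise[:, 2])   # [3 3]: one difference and one original value mixed

# Per-column condition: shape (6,), broadcast over the 2 rows by np.where.
col_ok = (d >= 0).all(axis=0)
percolumn = np.where(col_ok, d, X[1:])
print(percolumn[:, 2])     # [4 3]: the original column, first row dropped
```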

How can I change the condition so that it applies to whole columns? Is there a more efficient way to implement this?

Upvotes: 2

Views: 382

Answers (2)

quantummind

Reputation: 2136

You can try this:

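import numpy
b = a[1:] - a[:-1]                                   # differences between consecutive rows
decrease = numpy.where(numpy.min(b, axis=0) < 0)[0]  # indices of columns that decrease somewhere
b[:, decrease] = a[1:, decrease]                     # put the original values back in those columns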

You can also do that in one expression:

numpy.where(numpy.min(a[1:]-a[:-1],axis=0)>=0, a[1:]-a[:-1], a[1:])
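The one-expression form works because the condition is reduced to one boolean per column before np.where sees it, and a 1-D condition with one entry per column broadcasts against the 2-D difference array row by row. A short check of the shapes involved, using the question's matrix (computing the differences once instead of twice is just a minor tidy-up):

```python
import numpy as np

a = np.array([[1, 1, 1, 2, 3, 4],
              [1, 3, 4, 3, 4, 5],
              [1, 7, 3, 4, 2, 7]])

d = a[1:] - a[:-1]               # shape (2, 6): consecutive differences
cond = np.min(d, axis=0) >= 0    # shape (6,): True where the column never decreases
print(d.shape, cond.shape)       # (2, 6) (6,)

# Each column is taken entirely from d or entirely from a[1:].
result = np.where(cond, d, a[1:])
print(result)
```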

Upvotes: 1

Divakar

Reputation: 221774

You could use boolean indexing:

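import numpy as np

# Get the differences along the first axis (between consecutive rows)
diffs = np.diff(a, axis=0)

# Mask of columns that decrease somewhere (contain a negative difference)
mask = (diffs < 0).any(0)

# In those columns, put back the original elements (first row dropped)
diffs[:, mask] = a[1:, mask]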

Sample run:

In [141]: a
Out[141]: 
array([[1, 1, 1, 2, 3, 4],
       [1, 3, 4, 3, 4, 5],
       [1, 7, 3, 4, 2, 7]])

In [142]: diffs = np.diff(a,axis=0)
     ...: mask = (diffs<0).any(0)
     ...: diffs[:,mask] = a[1:,mask]
     ...: 

In [143]: diffs
Out[143]: 
array([[0, 2, 4, 1, 4, 1],
       [0, 4, 3, 1, 2, 2]])
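One caveat worth checking, assuming the input may have an unsigned integer dtype: np.diff then wraps around instead of going negative, so `diffs < 0` never fires. A sketch of the same boolean-indexing approach that builds the mask by comparing neighbouring rows directly (the function name `diff_nondecreasing_columns` is hypothetical):

```python
import numpy as np

def diff_nondecreasing_columns(a):
    """Boolean-indexing approach, with a dtype-independent mask (hypothetical name)."""
    a = np.asarray(a)
    diffs = np.diff(a, axis=0)
    mask = (a[1:] < a[:-1]).any(axis=0)   # columns that decrease somewhere
    diffs[:, mask] = a[1:, mask]          # restore them, first row dropped
    return diffs

u = np.array([[1, 1, 1, 2, 3, 4],
              [1, 3, 4, 3, 4, 5],
              [1, 7, 3, 4, 2, 7]], dtype=np.uint8)

print((np.diff(u, axis=0) < 0).any(axis=0))  # all False: uint8 differences wrap around
print(diff_nondecreasing_columns(u))
```

For signed integer or float arrays both masks select the same columns; the direct comparison just avoids relying on the sign of the subtraction.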

Upvotes: 0
