user109387

Eliminating array rows that fail to meet two conditions

Consider arrays m and n, which have identical shapes. m and n always have an even number of columns, and I have added spaces to emphasize that each row is made up of PAIRS of elements.

import numpy as np

m = np.array([[5,3,  6,7,  3,8],
              [5,4,  5,1,  4,5],
              [5,4,  2,4,  4,6],
              [2,2,  2,3,  8,5],
              [2,7,  8,7,  1,2],
              [2,7,  8,7,  3,2],
              [2,7,  8,7,  4,2]]) 

n = np.array([[1,3,  6,7,  1,12],
              [2,4,  2,9,  4,9],
              [1,5,  5,12, 9,1],
              [5,4,  5,6,  9,5],
              [5,4,  1,4,  1,5],
              [5,4,  1,7,  9,5],
              [5,4,  1,5,  4,11]]) 
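Treating each row as a sequence of pairs is just a reshape. A minimal sketch, shown here on the first two rows of m only:

```python
import numpy as np

# First two rows of m from the question; each row holds 3 pairs.
m = np.array([[5, 3, 6, 7, 3, 8],
              [5, 4, 5, 1, 4, 5]])

# (rows, 2*k) -> (rows, k, 2): axis 1 indexes the pair, axis 2 the element in the pair.
pairs = m.reshape(m.shape[0], -1, 2)
print(pairs.shape)        # (2, 3, 2)
print(pairs[0].tolist())  # [[5, 3], [6, 7], [3, 8]]
```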

My overall objective is to eliminate rows from m that fail either of two TEST conditions:

TEST 1: we keep rows of m only if every pair has a common element with some other pair in the row. This has already been answered very capably by Quang Hoang on Nov 13, 2020 (https://stackoverflow.com/a/64814379/11188140). I include it here because the code will be helpful with TEST 2, which follows.

Using this code, the 1st row of m is rejected because pair (6,7) shares no element with the other pairs in the row. The 4th row of m is also rejected because its last pair, (8,5), shares no element with the other pairs. The earlier code, which works perfectly, and its output are as follows:

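# View each row as (number_of_pairs, 2)
a = m.reshape(m.shape[0],-1,2)
# Off-diagonal mask so a pair is never compared with itself
mask = ~np.eye(a.shape[1], dtype=bool)[...,None]

# A row is valid if every pair shares an element (same position or swapped) with another pair
is_valid = (((a[...,None,:]==a[:,None,...])&mask).any(axis=(-1,-2))
            |((a[...,None,:]==a[:,None,:,::-1])&mask).any(axis=(-1,-2))
           ).all(-1)

m[is_valid]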

The output, m[is_valid], is:

              [[5,4,  5,1,  4,5],
               [5,4,  2,4,  4,6],
               [2,7,  8,7,  1,2],
               [2,7,  8,7,  3,2],
               [2,7,  8,7,  4,2]]
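For checking the vectorized mask, a plain-Python loop version of TEST 1 can help. This is a minimal sketch, not Quang Hoang's method; the helper names `shares_element` and `passes_test1` are hypothetical. It should reproduce the output above:

```python
import numpy as np

m = np.array([[5,3,  6,7,  3,8],
              [5,4,  5,1,  4,5],
              [5,4,  2,4,  4,6],
              [2,2,  2,3,  8,5],
              [2,7,  8,7,  1,2],
              [2,7,  8,7,  3,2],
              [2,7,  8,7,  4,2]])

def shares_element(p, q):
    # Hypothetical helper: True when two pairs have at least one value in common.
    return bool(set(p) & set(q))

def passes_test1(row):
    # Hypothetical helper: split the flat row into consecutive pairs.
    pairs = [tuple(row[i:i + 2]) for i in range(0, len(row), 2)]
    # Every pair must share an element with at least one *other* pair.
    return all(
        any(shares_element(p, q) for j, q in enumerate(pairs) if j != i)
        for i, p in enumerate(pairs)
    )

keep = np.array([passes_test1(row) for row in m])
print(m[keep])
```

The loop is O(k^2) per row in pure Python, so it is only meant as a reference to test the broadcasting version against.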

TEST 2: we keep a row of m only if the row of n with the same index has the SAME PAIR MATCHINGS as the m row. The n row may have 'extra' pair matchings as well, but it MUST include every pair matching that the m row has.

Three examples make this clear:

a) 5th row of m and n:  '[2,7,  8,7,  1,2]' and '[5,4,  1,4,  1,5]'
   In m, the 1st & 2nd pairs share an element, and the 1st & 3rd pairs share an element.
   n has both of these matchings, so **TEST 2 PASSES**, and we keep this m row.
   The fact that the 2nd & 3rd pairs of n also share an element is immaterial.

b) 6th row of m and n:  '[2,7,  8,7,  3,2]' and '[5,4,  1,7,  9,5]'
   In m, the 1st & 2nd pairs share an element, and the 1st & 3rd pairs share an element.
   n DOES NOT have BOTH of these matchings (its 1st & 2nd pairs share nothing), so **TEST 2 FAILS**. The row may be eliminated from m (and from n too, if that's needed).

c) 3rd row of m and n:  '[5,4,  2,4,  4,6]' and '[1,5,  5,12, 9,1]'
   In m, the 1st & 2nd pairs share an element, the 2nd & 3rd pairs share an element, and the 1st & 3rd pairs share an element.
   n lacks the 2nd & 3rd pair sharing that m has, so **TEST 2 FAILS**. The row may be eliminated from m (and from n too, if that's needed).
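One way to read TEST 2: collect a row's pair matchings as a set of index pairs (i, j), then require the m set to be a subset of the n set. A minimal sketch under that reading, with hypothetical helper names `pair_matchings` and `passes_test2`, applied to examples a), b) and c):

```python
def pair_matchings(row):
    """Hypothetical helper: set of (i, j), i < j, for pairs i and j that share an element."""
    pairs = [set(row[k:k + 2]) for k in range(0, len(row), 2)]
    return {
        (i, j)
        for i in range(len(pairs))
        for j in range(i + 1, len(pairs))
        if pairs[i] & pairs[j]
    }

def passes_test2(m_row, n_row):
    # Every matching in the m row must also appear in the n row; extras in n are allowed.
    return pair_matchings(m_row) <= pair_matchings(n_row)

examples = {
    "a": ([2, 7, 8, 7, 1, 2], [5, 4, 1, 4, 1, 5]),
    "b": ([2, 7, 8, 7, 3, 2], [5, 4, 1, 7, 9, 5]),
    "c": ([5, 4, 2, 4, 4, 6], [1, 5, 5, 12, 9, 1]),
}
for name, (m_row, n_row) in examples.items():
    print(name, sorted(pair_matchings(m_row)), sorted(pair_matchings(n_row)),
          passes_test2(m_row, n_row))
```

Example a) passes because n's set {(0,1), (0,2), (1,2)} contains m's {(0,1), (0,2)}; b) and c) fail because n is missing (0,1) and (1,2) respectively.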

The final output, m, after passing both tests, is:

              [[5,4,  5,1,  4,5],
               [2,7,  8,7,  1,2],
               [2,7,  8,7,  4,2]]


Answers (1)

Akshay Sehgal

Here is a vectorized solution for the 2 tests you mention. I have also modified the TEST 1 code a bit so that it flows nicely into TEST 2.

import numpy as np

m = np.array([[5,3,  6,7,  3,8],
              [5,4,  5,1,  4,5],
              [5,4,  2,4,  4,6],
              [2,2,  2,3,  8,5],
              [2,7,  8,7,  1,2],
              [2,7,  8,7,  3,2],
              [2,7,  8,7,  4,2]]) 

n = np.array([[1,3,  6,7,  1,12],
              [2,4,  2,9,  4,9],
              [1,5,  5,12, 9,1],
              [5,4,  5,6,  9,5],
              [5,4,  1,4,  1,5],
              [5,4,  1,7,  9,5],
              [5,4,  1,5,  4,11]])


mm = m.reshape(m.shape[0],-1,2) #(7,3,2)
nn = n.reshape(n.shape[0],-1,2) #(7,3,2)

#broadcast((7,3,1,2,1), (7,1,3,1,2)) -> (7,3,3,2,2) -> (7,3,3)
matches_m = np.any(mm[:,:,None,:,None] == mm[:,None,:,None,:], axis=(-1,-2))  #(7,3,3)
matches_n = np.any(nn[:,:,None,:,None] == nn[:,None,:,None,:], axis=(-1,-2))  #(7,3,3)

mask = ~np.eye(mm.shape[1], dtype=bool)[None,:]  #(1,3,3)

is_valid1 = np.all(np.any(matches_m&mask, axis=-1), axis=-1)  #(7,)
is_valid2 = np.all(~(matches_m^matches_n)|matches_n, axis=(-1,-2)) #(7,)

print(m[is_valid1 & is_valid2])
# [[5 4 5 1 4 5]
#  [2 7 8 7 1 2]
#  [2 7 8 7 4 2]]
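The same steps can be packaged as a reusable function. This is a sketch under the assumption that m and n have the same shape and an even number of columns; the name `filter_rows` is hypothetical. Because it reshapes with -1, it is not tied to 3 pairs per row:

```python
import numpy as np

def filter_rows(m, n):
    """Hypothetical wrapper: keep the rows of m that pass TEST 1 and TEST 2.

    Assumes m and n have the same shape and an even number of columns.
    """
    mm = m.reshape(m.shape[0], -1, 2)  # (rows, k, 2)
    nn = n.reshape(n.shape[0], -1, 2)  # (rows, k, 2)

    # (rows, k, k): True where pair i and pair j share at least one element
    matches_m = np.any(mm[:, :, None, :, None] == mm[:, None, :, None, :], axis=(-1, -2))
    matches_n = np.any(nn[:, :, None, :, None] == nn[:, None, :, None, :], axis=(-1, -2))

    # Off-diagonal mask: ignore each pair's comparison with itself
    mask = ~np.eye(mm.shape[1], dtype=bool)[None, :]

    is_valid1 = np.all(np.any(matches_m & mask, axis=-1), axis=-1)
    is_valid2 = np.all(~(matches_m ^ matches_n) | matches_n, axis=(-1, -2))
    return m[is_valid1 & is_valid2]


m = np.array([[5,3,  6,7,  3,8],
              [5,4,  5,1,  4,5],
              [5,4,  2,4,  4,6],
              [2,2,  2,3,  8,5],
              [2,7,  8,7,  1,2],
              [2,7,  8,7,  3,2],
              [2,7,  8,7,  4,2]])

n = np.array([[1,3,  6,7,  1,12],
              [2,4,  2,9,  4,9],
              [1,5,  5,12, 9,1],
              [5,4,  5,6,  9,5],
              [5,4,  1,4,  1,5],
              [5,4,  1,7,  9,5],
              [5,4,  1,5,  4,11]])

print(filter_rows(m, n))
```

Memory grows as rows * k * k * 4 booleans, which is fine for a handful of pairs per row.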

Explanation -

TEST 1

  1. The first step is to reshape m and n to (7,3,2) so that we can broadcast over axis=1.
  2. Next, we need to compare every pair with every other pair in the same row. So the expected output is (7,3,3): 7 rows, each with a 3x3 pair-vs-pair comparison.
  3. But each of those comparisons must also check every element of one pair against every element of the other (2x2). Meaning I need a (7,3,3,2,2) array. This comes from broadcasting 2 arrays of shape (7,3,1,2,1) and (7,1,3,1,2).
  4. Finally, an any() over the last 2 axes gives a (7,3,3) array where, for each of the 7 rows, entry [i,j] is True if pair i and pair j have a common element. THIS IS ALSO THE ARRAY THAT CONTAINS THE MATCHES AND WILL BE IMPORTANT FOR TEST2!
  5. Next, since comparing a pair with itself always gives True, you want to ignore the diagonal, so create a mask for it.
  6. Using the mask, check that every pair has at least 1 match with a pair other than itself, and that solves TEST1.
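To make the shapes in steps 1 to 6 concrete, here is a small sketch that runs them on the 5th row of m alone (kept 2-D so the axes match the full version):

```python
import numpy as np

row = np.array([[2, 7, 8, 7, 1, 2]])  # 5th row of m, shape (1, 6)
mm = row.reshape(row.shape[0], -1, 2)  # (1, 3, 2)

left = mm[:, :, None, :, None]   # (1, 3, 1, 2, 1)
right = mm[:, None, :, None, :]  # (1, 1, 3, 1, 2)
eq = left == right               # (1, 3, 3, 2, 2): each element of pair i vs each element of pair j
matches = eq.any(axis=(-1, -2))  # (1, 3, 3): does pair i share anything with pair j?
print(eq.shape, matches.shape)
print(matches[0].astype(int))
# [[1 1 1]
#  [1 1 0]
#  [1 0 1]]

mask = ~np.eye(mm.shape[1], dtype=bool)[None, :]  # (1, 3, 3), False on the diagonal
print(np.all(np.any(matches & mask, axis=-1), axis=-1))  # [ True]
```

The 0 at [1,2] and [2,1] reflects that (8,7) and (1,2) share nothing, yet the row still passes TEST 1 because each pair matches some other pair.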

TEST 2

  1. Apply the same first 4 steps to n to get its (7,3,3) matches array.
  2. Here, any match that exists in m MUST exist in n, but a match that exists in n need not exist in m. The required logic is:
# a: match present in m, b: match present in n
a = np.array([True, False, True, False])
b = np.array([True, True, False, False])

~(a^b)|b
#array([ True,  True, False,  True])
  3. Applying this to the 2 (7,3,3) match arrays, even a single False in a row's (3,3) block means that some match between pairs in m is not reflected in n, so that row gets False. Reducing over axis=(-1,-2) results in a (7,) array.
  4. This gives you TEST2.
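The truth table above is exactly logical implication ("a match in m implies a match in n"), so `~(a^b)|b` reduces to `~a|b`. A quick check of that equivalence on all four combinations:

```python
import numpy as np

a = np.array([True, False, True, False])  # match present in m?
b = np.array([True, True, False, False])  # match present in n?

original = ~(a ^ b) | b
implication = ~a | b  # False only when the match is in m but missing from n
print(original)     # [ True  True False  True]
print(implication)  # [ True  True False  True]
```

So `is_valid2` could equivalently be written as `np.all(~matches_m | matches_n, axis=(-1,-2))`; both forms give the same result, the shorter one just states the "m implies n" rule directly.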

Hope this solves your problem.
