ironv

Reputation: 1058

identifying sub-arrays in numpy

I have two two-dimensional arrays a and b (the number of columns in a is less than or equal to the number of columns in b). I would like to find an efficient way of matching a row in array a to a contiguous part of a row in array b.

a = np.array([[ 25,  28],
              [ 84,  97],
              [105,  24],
              [ 28, 900]])

b = np.array([[ 25,  28,  84,  97],
              [ 22,  25,  28, 900],
              [ 11,  12, 105,  24]])

The output should be np.array([[0,0], [0,1], [1,0], [2,2], [3,1]]). Row 0 in array a matches row 0 in array b (first two positions) and row 1 in array b (second and third positions); row 1 in array a matches row 0 in array b (third and fourth positions); and so on.
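For reference, the matching rule can be spelled out as a naive double loop (a minimal sketch for clarity; find_matches is just an illustrative name, and this is far too slow for large arrays):

import numpy as np

def find_matches(a, b):
    # check every alignment of every row of a against every row of b
    n = a.shape[1]
    pairs = []
    for i, row in enumerate(a):
        for j in range(b.shape[0]):
            if any(np.array_equal(row, b[j, k:k + n])
                   for k in range(b.shape[1] - n + 1)):
                pairs.append([i, j])
    return np.array(pairs)

With the arrays above, this reproduces the expected output.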

Upvotes: 1

Views: 353

Answers (2)

Divakar

Reputation: 221704

We can leverage scikit-image's view_as_windows (based on np.lib.stride_tricks.as_strided) for efficient patch extraction, and then compare those patches against each row of a, all of it in a vectorized manner. Then, get the matching indices with np.argwhere -

# a and b from posted question
In [325]: from skimage.util.shape import view_as_windows

In [428]: w = view_as_windows(b,(1,a.shape[1]))

In [429]: np.argwhere((w == a).all(-1).any(-2))[:,::-1]
Out[429]: 
array([[0, 0],
       [1, 0],
       [0, 1],
       [3, 1],
       [2, 2]])

Alternatively, we could get the indices in the order of rows in a by pushing a's first axis forward while performing broadcasted comparisons -

In [444]: np.argwhere((w[:,:,0] == a[:,None,None,:]).all(-1).any(-1))
Out[444]: 
array([[0, 0],
       [0, 1],
       [1, 0],
       [2, 2],
       [3, 1]])
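
For completeness, on NumPy 1.20+ the same windows can be built without scikit-image using np.lib.stride_tricks.sliding_window_view (a sketch; it produces windows of the same shape as view_as_windows, so the expressions above apply unchanged):

from numpy.lib.stride_tricks import sliding_window_view

w = sliding_window_view(b, (1, a.shape[1]))    # same (1, a-width) windows
np.argwhere((w == a).all(-1).any(-2))[:,::-1]  # same result as Out[429]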

Upvotes: 4

rayryeng

Reputation: 104555

Another way I can think of is to loop over each row in a and perform a 2D correlation between b, which you can consider as a 2D signal, and that row. We would look for results that are equal to the sum of squares of the values in the row. If we subtract this sum of squares from the correlation result, matches show up as zeros: any row of b that gives a 0 means the subarray was found in that row. For example, for the row [25, 28] the sum of squares is 25*25 + 28*28 = 1409, and correlating it with the matching window in b also gives 25*25 + 28*28 = 1409, so the difference is 0. If you are using floating-point numbers, you may want to compare against some small threshold just above 0 instead of exactly 0.

If you can use SciPy, the scipy.signal.correlate2d method is what I had in mind.

import numpy as np
from scipy.signal import correlate2d

a = np.array([[ 25,  28],
              [ 84,  97],
              [105,  24]])

b = np.array([[ 25,  28,  84,  97],
              [ 22,  25,  28, 900],
              [ 11,  12, 105,  24]])

EPS = 1e-8
result = []
for (i, row) in enumerate(a):
    # a matching window makes the correlation equal the row's sum of
    # squares, so matches appear as (near-)zeros after the subtraction
    out = correlate2d(b, row[None,:], mode='valid') - np.square(row).sum()
    locs = np.where(np.abs(out) <= EPS)[0]  # rows of b containing a zero

    unique_rows = np.unique(locs)
    for res in unique_rows:
        result.append((i, res))

We get:

In [32]: result
Out[32]: [(0, 0), (0, 1), (1, 0), (2, 2)]

The time complexity of this could be better, especially since we're looping over each row of a to find any subarrays in b.
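If the loop is a concern, the same sum-of-squares test can be applied to all rows of a at once with a single product over all windows of b. A sketch, assuming NumPy 1.20+ for sliding_window_view, and with the same caveat that a matching correlation value is a necessary rather than sufficient condition:

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# every contiguous slice of every row of b, with the width of a's rows
win = sliding_window_view(b, a.shape[1], axis=1)    # shape (3, 3, 2) here

# corr[i, j, k] = dot(a[i], b[j, k:k+width]) for all i, j, k at once
corr = np.einsum('id,jkd->ijk', a, win)

# a match requires the dot product to equal the row's sum of squares
ss = np.square(a).sum(axis=1)
hits = np.isclose(corr, ss[:, None, None]).any(-1)  # (row of a, row of b)
result = np.argwhere(hits)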

Upvotes: 1
