ysearka

Reputation: 3855

Perform group operation on 2D numpy array

I have a 2D NumPy array (a similarity matrix) for which I need to compute block-wise averages. For instance, with the following matrix:

sima = np.array([[1,0.8,0.7,0.3,0.1,0.5],
                 [0.8,1,0.1,0.5,0.2,0.5],
                 [0.7,0.1,1,0.1,0.3,0.9],
                 [0.3,0.5,0.1,1,0.8,0.5],
                 [0.1,0.2,0.3,0.8,1,0.5],
                 [0.5,0.5,0.9,0.5,0.5,1]])

And the labels vector:

labels = np.array([1,1,1,2,2,3])

This means that the first three rows of the matrix (and the first three columns, since a similarity matrix is symmetric) correspond to cluster 1, the next two correspond to cluster 2, and the last one corresponds to cluster 3.

I need to compute the average of each block in sima defined by a pair of labels in labels, yielding the following output:

0.69 0.25 0.63 
0.25 0.90 0.50 
0.63 0.50 1.00
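
To make the target concrete, here is a minimal check of a single entry of that output, the block pairing cluster 1 rows with cluster 2 columns. The names `block` and `block_mean` are only illustrative:

```python
import numpy as np

sima = np.array([[1, 0.8, 0.7, 0.3, 0.1, 0.5],
                 [0.8, 1, 0.1, 0.5, 0.2, 0.5],
                 [0.7, 0.1, 1, 0.1, 0.3, 0.9],
                 [0.3, 0.5, 0.1, 1, 0.8, 0.5],
                 [0.1, 0.2, 0.3, 0.8, 1, 0.5],
                 [0.5, 0.5, 0.9, 0.5, 0.5, 1]])
labels = np.array([1, 1, 1, 2, 2, 3])

# Rows labelled 1 and columns labelled 2 form a 3x2 block.
block = sima[np.ix_(labels == 1, labels == 2)]

# (0.3 + 0.1 + 0.5 + 0.2 + 0.1 + 0.3) / 6 = 0.25, the entry at row 1, column 2.
block_mean = block.mean()
print(block_mean)
```

Every other entry of the 3x3 output is the same computation for a different pair of labels.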

So far, I have a working solution using a double loop on labels and masked arrays:

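import numpy as np
import pandas as pd

# Every row of labels_matrix is the labels vector, so labels_matrix[r, c] == labels[c]
labels_matrix = np.tile(labels, (len(labels), 1))
output = pd.DataFrame(np.zeros(shape=(3, 3)))

for i in range(3):
  for j in range(3):
    # Mask every cell that is outside the block (row label i+1, column label j+1)
    mask = (labels_matrix != j+1) | (labels_matrix.T != i+1)
    output.loc[i, j] = np.ma.array(sima, mask=mask).mean()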

This code yields the correct output, but my actual matrix is 50kx50k, and this code takes forever to compute. How could I make it faster?

Note: I need a different order of magnitude in speed, so I expect using tricks like the symmetry of the similarity matrix won't be enough.

Upvotes: 1

Views: 83

Answers (1)

Divakar

Reputation: 221574

For sorted labels, we can use np.add.reduceat to sum each block along both axes and then divide by the block sizes -

In [62]: idx = np.flatnonzero(np.r_[True,labels[:-1] != labels[1:],True])

In [63]: c = np.diff(idx)

In [64]: sums = np.add.reduceat(np.add.reduceat(sima,idx[:-1],axis=0),idx[:-1],axis=1)

In [65]: sums/(c[:,None]*c)
Out[65]: 
array([[0.68888889, 0.25      , 0.63333333],
       [0.25      , 0.9       , 0.5       ],
       [0.63333333, 0.5       , 1.        ]])
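
The session above could be wrapped into a reusable function along these lines. This is only a sketch: the name `block_means` and the reordering step for unsorted labels are assumptions added here, not part of the original answer.

```python
import numpy as np

def block_means(sim, labels):
    """Mean of every label-by-label block of a square matrix (hypothetical helper)."""
    labels = np.asarray(labels)
    # Assumption: if labels are not sorted, reorder rows and columns so that
    # equal labels are contiguous. This copies the matrix, so skip it when sorted.
    if np.any(labels[:-1] > labels[1:]):
        order = np.argsort(labels, kind="stable")
        labels = labels[order]
        sim = sim[np.ix_(order, order)]
    # Start index of each run of equal labels, followed by len(labels).
    idx = np.flatnonzero(np.r_[True, labels[:-1] != labels[1:], True])
    counts = np.diff(idx)  # size of each run
    # Sum contiguous row groups, then contiguous column groups.
    sums = np.add.reduceat(np.add.reduceat(sim, idx[:-1], axis=0), idx[:-1], axis=1)
    # Divide each block sum by (rows in block) * (columns in block).
    return sums / (counts[:, None] * counts)

sima = np.array([[1, 0.8, 0.7, 0.3, 0.1, 0.5],
                 [0.8, 1, 0.1, 0.5, 0.2, 0.5],
                 [0.7, 0.1, 1, 0.1, 0.3, 0.9],
                 [0.3, 0.5, 0.1, 1, 0.8, 0.5],
                 [0.1, 0.2, 0.3, 0.8, 1, 0.5],
                 [0.5, 0.5, 0.9, 0.5, 0.5, 1]])
labels = np.array([1, 1, 1, 2, 2, 3])

result = block_means(sima, labels)
print(result.round(2))
```

Rows and columns of the result follow the sorted order of the distinct labels. For a 50k x 50k matrix the reordering branch needs a second full copy in memory, so sorting the data once up front is likely preferable.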

Upvotes: 2
