RRR

Reputation: 63

Create a function that computes the pairwise cosine similarity of all row vectors in a 2-D matrix, using only NumPy.

For example, given the matrix

array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14]])

it should return

array([[1.        , 0.91465912, 0.87845859],
       [0.91465912, 1.        , 0.99663684],
       [0.87845859, 0.99663684, 1.        ]])

where the (i, j) entry of the result is the cosine similarity between the row vector arr[i] and the row vector arr[j]: cos_sim[i, j] == CosSim(arr[i], arr[j]).

As usual, the cosine similarity between two vectors x and y is defined as:

$$\mathrm{CosSim}(x, y) = \frac{x \cdot y}{\lVert x \rVert_2 \, \lVert y \rVert_2}$$
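As a quick worked check of the definition on the example matrix, take the first two rows, arr[0] = (0, 1, 2, 3, 4) and arr[1] = (5, 6, 7, 8, 9):

$$\mathrm{CosSim}(\mathrm{arr}[0], \mathrm{arr}[1]) = \frac{0\cdot 5 + 1\cdot 6 + 2\cdot 7 + 3\cdot 8 + 4\cdot 9}{\sqrt{0^2+1^2+2^2+3^2+4^2}\,\sqrt{5^2+6^2+7^2+8^2+9^2}} = \frac{80}{\sqrt{30}\,\sqrt{255}} = \frac{80}{\sqrt{7650}} \approx 0.91466$$

This matches the (0, 1) and (1, 0) entries of the expected result.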

This function should return an np.ndarray of shape (arr.shape[0], arr.shape[0]).
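Before vectorizing, it can help to have a direct translation of the definition to test against. Below is a minimal reference sketch (the name cosine_similarity_loop is hypothetical) that loops over every pair of rows. It uses only NumPy, but it performs n² Python-level iterations, so it is meant for checking results on small inputs rather than for speed.

```python
import numpy as np

def cosine_similarity_loop(arr: np.ndarray) -> np.ndarray:
    """Hypothetical reference implementation: apply the definition pair by pair."""
    arr = np.asarray(arr, dtype=float)
    n = arr.shape[0]
    out = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            # dot product divided by the product of the two Euclidean norms
            out[i, j] = arr[i] @ arr[j] / (np.linalg.norm(arr[i]) * np.linalg.norm(arr[j]))
    return out

# np.arange(15).reshape(3, 5) reproduces the example matrix from the question
arr = np.arange(15).reshape(3, 5)
print(cosine_similarity_loop(arr))
```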

Upvotes: 3

Views: 3245

Answers (2)

ggaurav

Reputation: 1804

Start with the example matrix a:

array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14]])

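Using the second form of the formula, where each vector is normalized first (call the normalized vectors p and q):

$$\mathrm{CosSim}(x, y) = \left(\frac{x}{\lVert x \rVert_2}\right)^{\top} \left(\frac{y}{\lVert y \rVert_2}\right) = p^{\top} q$$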

p = a / np.linalg.norm(a, 2, axis=1).reshape(-1,1)
p
array([[0.        , 0.18257419, 0.36514837, 0.54772256, 0.73029674],
       [0.31311215, 0.37573457, 0.438357  , 0.50097943, 0.56360186],
       [0.37011661, 0.40712827, 0.44413993, 0.48115159, 0.51816325]])

Note that the norm has to be computed row-wise, hence axis=1 above. The norms come back as a 1-D array of shape (3,), so they have to be reshaped into a (3, 1) column for broadcasting to divide each row by its own norm. Also, the formula above is written for vectors; when it is applied to matrices whose rows are the vectors, the transpose moves to the second factor (np.dot(p, q.T) rather than transposing the first operand).
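The shape reasoning can be checked directly. This sketch uses the same a as the answer and shows why the reshape is needed. Dividing by the (3,) array of norms directly fails, because broadcasting aligns it with the last axis (length 5). Passing keepdims=True to np.linalg.norm is an alternative to the reshape that yields the (3, 1) column in one step.

```python
import numpy as np

a = np.arange(15).reshape(3, 5)

norms = np.linalg.norm(a, 2, axis=1)    # one norm per row
print(norms.shape)                       # (3,)

try:
    a / norms                            # (3, 5) / (3,): trailing axes 5 and 3 do not match
except ValueError as e:
    print("broadcast error:", e)

col = norms.reshape(-1, 1)               # column vector, so each row is divided by its own norm
print(col.shape)                         # (3, 1)

# Alternative: keepdims=True keeps the reduced axis, giving shape (3, 1) directly
same = np.linalg.norm(a, 2, axis=1, keepdims=True)
print(same.shape)                        # (3, 1)

p = a / col                              # broadcasting: (3, 5) / (3, 1) -> (3, 5)
print(np.linalg.norm(p, axis=1))         # every row now has unit length
```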

In this case q is just p itself, so the cosine similarity matrix is

np.dot(p, p.T)
array([[1.        , 0.91465912, 0.87845859],
       [0.91465912, 1.        , 0.99663684],
       [0.87845859, 0.99663684, 1.        ]])
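The steps above can be wrapped into the function the question asks for. Below is a minimal sketch under stated assumptions. The name cos_sim is hypothetical, keepdims=True replaces the reshape, and p @ p.T is equivalent to np.dot(p, p.T) for 2-D arrays. An all-zero row, for which cosine similarity is undefined, is assumed to map to 0 rather than nan. Drop the where= guard if you would rather see nan.

```python
import numpy as np

def cos_sim(arr: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity of the rows of a 2-D array (hypothetical name).

    Returns an array of shape (arr.shape[0], arr.shape[0]).
    """
    arr = np.asarray(arr, dtype=float)
    norms = np.linalg.norm(arr, axis=1, keepdims=True)   # shape (n, 1)
    # Assumption: leave all-zero rows as zeros instead of dividing by 0 (which would give nan)
    p = np.divide(arr, norms, out=np.zeros_like(arr), where=norms != 0)
    return p @ p.T                                        # (n, m) @ (m, n) -> (n, n)

print(cos_sim(np.arange(15).reshape(3, 5)))
```

On the example matrix, this should agree with the np.dot(p, p.T) result up to floating-point rounding.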

Upvotes: 2

Quang Hoang

Reputation: 150735

Try:

from scipy.spatial.distance import cdist

1 - cdist(a, a, metric='cosine')

Output:

array([[1.        , 0.91465912, 0.87845859],
       [0.91465912, 1.        , 0.99663684],
       [0.87845859, 0.99663684, 1.        ]])
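Note that cdist(XA, XB, ...) takes two arrays and returns a matrix of shape (XA.shape[0], XB.shape[0]). SciPy's 'cosine' metric is a distance, 1 minus the cosine similarity, hence the 1 - above. For the NumPy-only route the question asks for, the normalize-then-multiply approach generalizes the same way, with q built from a second matrix. Below is a minimal sketch. The name cross_cos_sim is hypothetical, and the function assumes neither input has an all-zero row.

```python
import numpy as np

def cross_cos_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Cosine similarity between every row of a and every row of b (hypothetical helper).

    Assumes no all-zero rows. The result has shape (a.shape[0], b.shape[0]),
    the same shape as 1 - cdist(a, b, metric='cosine').
    """
    p = a / np.linalg.norm(a, axis=1, keepdims=True)   # normalize rows of a
    q = b / np.linalg.norm(b, axis=1, keepdims=True)   # normalize rows of b
    return p @ q.T

a = np.arange(15).reshape(3, 5)
b = np.array([[1, 0, 0, 0, 0], [1, 1, 1, 1, 1]])
print(cross_cos_sim(a, b).shape)   # (3, 2)
```

Calling cross_cos_sim(a, a) reduces to the single-matrix case from the question.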

Upvotes: 2
